diff --git a/MindIE/MindIE-Torch/built-in/audio/whisperX/compile_whisper.py b/MindIE/MindIE-Torch/built-in/audio/whisperX/compile_whisper.py
new file mode 100644
index 0000000000000000000000000000000000000000..3f4bfd3081408211830fdea11c556a00a74691f9
--- /dev/null
+++ b/MindIE/MindIE-Torch/built-in/audio/whisperX/compile_whisper.py
@@ -0,0 +1,177 @@
+import os
+import argparse
+import math
+import time
+import torch
+import mindietorch
+from mindietorch._enums import dtype
+from modeling_whisper import MindieWhisperForConditionalGeneration
+from config import CompileInfo
+
+
+def compile_encoder(model: MindieWhisperForConditionalGeneration,
+                    args: argparse.Namespace,
+                    compile_info: CompileInfo):
+    encoder = model.get_encoder()
+
+    class Encoder(torch.nn.Module):
+
+        def __init__(self, model):
+            super().__init__()
+            self.model = model
+
+        def forward(self, input_features):
+            return self.model(input_features=input_features, return_dict=False)
+
+    input_features = torch.randn([args.bs, compile_info.mel_feature_size, compile_info.max_frames])
+    encoder_traced = torch.jit.trace(Encoder(encoder), (input_features,))
+    input_info = [mindietorch.Input(shape=(args.bs, compile_info.mel_feature_size, compile_info.max_frames))]
+    try:
+        print("Start AOE optimization_level 1.")
+        mindietorch.compile(encoder_traced,
+                            inputs=input_info,
+                            precision_policy=mindietorch.PrecisionPolicy.FP16,
+                            soc_version=args.soc_version,
+                            optimization_level=1
+                            )
+    except Exception:
+        print("Using AOE to optimize the encoder failed; the encoder can still be compiled without it.")
+    else:
+        print("AOE optimization finished, start compiling encoder.")
+    compiled = mindietorch.compile(encoder_traced,
+                                   inputs=input_info,
+                                   precision_policy=mindietorch.PrecisionPolicy.FP16,
+                                   soc_version=args.soc_version,
+                                   )
+    save_file = os.path.join(args.save_path, f"{compile_info.prefix_name[0]}{args.bs}.ts")
+    torch.jit.save(compiled, save_file)
+    print(f"Compile encoder success, saved in {save_file}")
+
+
+def compile_prefill_decoder(model: MindieWhisperForConditionalGeneration,
+                            args: argparse.Namespace, compile_info: CompileInfo):
+    print("Start compiling prefill_decoder.")
+    encoder_outputs = torch.randn([args.bs, compile_info.encoder_seq_len, compile_info.hidden_size])
+    decoder_input_ids = torch.randint(1, 4, (args.bs, 1))
+    prefill_decoder_traced = torch.jit.trace(model.eval(), (decoder_input_ids, encoder_outputs))
+    input_info = [mindietorch.Input(shape=(args.bs, 1), dtype=dtype.INT64),
+                  mindietorch.Input(shape=(args.bs, compile_info.encoder_seq_len, compile_info.hidden_size))]
+    prefill_decoder_compiled = mindietorch.compile(prefill_decoder_traced,
+                                                   inputs=input_info,
+                                                   precision_policy=mindietorch.PrecisionPolicy.FP16,
+                                                   soc_version=args.soc_version)
+
+    save_file = os.path.join(args.save_path, f"{compile_info.prefix_name[1]}{args.bs}.ts")
+    torch.jit.save(prefill_decoder_compiled, save_file)
+    print(f"Compile prefill_decoder success, saved in {save_file}.")
+
+
+def compile_incre_decoder(args: argparse.Namespace, compile_info: CompileInfo):
+    class Decoder(torch.nn.Module):
+
+        def __init__(self, model):
+            super().__init__()
+            self.model = model
+
+        def forward(self, *args):
+            return self.model.forward(*args)[0]
+
+    mindie_whisper = MindieWhisperForConditionalGeneration.from_pretrained(args.model_path,
+                                                                           machine_type=args.machine_type,
+                                                                           is_incre_decode=True,
+                                                                           soc_version=args.soc_version)
+    decoder = Decoder(mindie_whisper)
+    print("Start compiling decoder.")
+
+    encoder_outputs = torch.randn([args.bs, compile_info.encoder_seq_len, compile_info.hidden_size])
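+    # Dummy inputs used only to trace the incremental decoder; shapes follow CompileInfo:
+    #   decoder_input_ids : (bs, 1)            int64, one new token per decode step
+    #   encoder_outputs   : (bs, 1500, 1280)   encoder hidden states
+    #   actual_seq_len    : (bs,)              current decode position of every sample
+    #   per decoder layer, four BSND caches:
+    #     self-attn  K/V  : (bs, max_decode_step=448, head_num=20, head_size=64)
+    #     cross-attn K/V  : (bs, encoder_seq_len=1500, head_num=20, head_size=64)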
+    decoder_input_ids = torch.randint(1, 4, (args.bs, 1))
+    actual_seq_len = torch.ones((args.bs,))
+    all_past_key_value = [torch.randn([args.bs, compile_info.max_decode_step, compile_info.head_num, compile_info.head_size]),
+                          torch.randn([args.bs, compile_info.max_decode_step, compile_info.head_num, compile_info.head_size]),
+                          torch.randn([args.bs, compile_info.encoder_seq_len, compile_info.head_num, compile_info.head_size]),
+                          torch.randn([args.bs, compile_info.encoder_seq_len, compile_info.head_num, compile_info.head_size])
+                          ] * compile_info.layer_nums
+    traced_args = [decoder_input_ids, encoder_outputs, actual_seq_len]
+    traced_args.extend(all_past_key_value)
+    traced_decoder = torch.jit.trace(decoder, traced_args)
+    # BSND
+    key_value_infos = [
+        mindietorch.Input(shape=(args.bs, compile_info.max_decode_step, compile_info.head_num, compile_info.head_size),
+                          dtype=dtype.FLOAT16),
+        mindietorch.Input(shape=(args.bs, compile_info.max_decode_step, compile_info.head_num, compile_info.head_size),
+                          dtype=dtype.FLOAT16),
+        mindietorch.Input(shape=(args.bs, compile_info.encoder_seq_len, compile_info.head_num, compile_info.head_size),
+                          dtype=dtype.FLOAT16),
+        mindietorch.Input(shape=(args.bs, compile_info.encoder_seq_len, compile_info.head_num, compile_info.head_size),
+                          dtype=dtype.FLOAT16
+                          )] * compile_info.layer_nums
+    input_info = [mindietorch.Input(shape=(args.bs, 1), dtype=dtype.INT64),  # input ids
+                  mindietorch.Input(shape=(args.bs, compile_info.encoder_seq_len, compile_info.hidden_size)),
+                  mindietorch.Input(shape=(args.bs,), dtype=dtype.INT64)]  # actual seq len
+
+    input_info.extend(key_value_infos)
+    float_size = 4
+    voc_size = 51866
+    buffer_size = math.ceil((args.bs * 1 * voc_size * float_size) / 1024 / 1024)
+    print(f"Set {buffer_size} MB for output.")
+    compiled_decoder = mindietorch.compile(traced_decoder,
+                                           inputs=input_info,
+                                           precision_policy=mindietorch.PrecisionPolicy.FP16,
+                                           soc_version=args.soc_version,
+                                           default_buffer_size_vec=[buffer_size])
+    save_file = os.path.join(args.save_path, f"{compile_info.prefix_name[2]}{args.bs}.ts")
+    torch.jit.save(compiled_decoder, save_file)
+
+    print(f"Compile whisper_decoder success, saved in {save_file}.")
+
+
+def compile_scatter_update(model: MindieWhisperForConditionalGeneration,
+                           args: argparse.Namespace, compile_info: CompileInfo):
+    class MindieScatter(torch.nn.Module):
+        def forward(self, past_key_value, indices, update_states):
+            out = torch.ops.aie.scatter_update(past_key_value, indices, update_states, 1)
+            return out
+
+    bs = args.bs
+    self_past_key_value = torch.randn([bs, compile_info.max_decode_step, compile_info.head_num, compile_info.head_size])
+    encoder_past_key_value = torch.randn([bs, compile_info.max_decode_step, compile_info.head_num, compile_info.head_size])
+    indices = torch.tensor([0] * bs)
+    update_states = torch.randn([bs, 1, compile_info.head_num, compile_info.head_size])
+    traced = torch.jit.trace(MindieScatter(), (self_past_key_value, indices, update_states))
+
+    self_attn_info = mindietorch.Input(shape=self_past_key_value.shape, dtype=mindietorch.dtype.FLOAT16)
+    encoder_attn_info = mindietorch.Input(shape=encoder_past_key_value.shape, dtype=mindietorch.dtype.FLOAT16)
+    indices_info = mindietorch.Input(shape=indices.shape, dtype=mindietorch.dtype.INT64)
+    update_states_info = mindietorch.Input(shape=encoder_past_key_value.shape, dtype=mindietorch.dtype.FLOAT16)
+
+    compile_self = mindietorch.compile(traced, inputs=[self_attn_info, indices_info, update_states_info],
+                                       soc_version=args.soc_version)
+    torch.jit.save(compile_self, f"{args.save_path}/{compile_info.prefix_name[3]}{args.bs}.ts")
+
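+    # Rough functional reference for torch.ops.aie.scatter_update(cache, indices, update, 1)
+    # as it is used here (an assumption, not the operator's documented contract): for every
+    # batch element b, the 1-step update is written in place into the BSND cache at position
+    # indices[b] along the sequence axis, i.e. roughly
+    #     cache[b, indices[b]] = update[b, 0]
+    # which in plain PyTorch corresponds to
+    #     cache[torch.arange(bs), indices] = update[:, 0]
+    # so the preallocated (bs, max_decode_step, head_num, head_size) cache is filled step by
+    # step while its shape stays fixed.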
compile_self = mindietorch.compile(traced, inputs=[encoder_attn_info, indices_info, update_states_info], + soc_version=args.soc_version) + torch.jit.save(compile_self, f"{args.save_path}/{compile_info.prefix_name[4]}{args.bs}.ts") + print("compile scatter success.") + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument('-model_path', type=str, required=True, help="please provide model path.") + parser.add_argument('-bs', type=int, default=8, help="please provide batch_size, default:8.") + parser.add_argument('-soc_version', type=str, required=True, + help="please provide soc_version.") + parser.add_argument('-save_path', type=str, default="compiled_models", help="compiled models save dir.") + parser.add_argument('-machine_type', type=str, choices=["300IPro", "800IA2"], default="800A2") + parser.add_argument('-device_id', type=int, default=0) + + args = parser.parse_args() + mindie_whisper = MindieWhisperForConditionalGeneration.from_pretrained(args.model_path, + soc_version=args.soc_version, + machine_type=args.machine_type) + device = f"npu:{args.device_id}" + print("Start compiling Mindie-Whisper, it will take some time, please wait.") + if not args.save_path: + raise ValueError("Please provide the directory where the compiled model saved.") + if not os.path.exists(args.save_path): + os.makedirs(args.save_path) + print(f"Directory {args.save_path} created.") + else: + print(f"Directory {args.save_path} already exists.") + + compile_scatter_update(mindie_whisper, args, CompileInfo) + compile_encoder(mindie_whisper, args, CompileInfo) + compile_prefill_decoder(mindie_whisper, args, CompileInfo) + compile_incre_decoder(args, CompileInfo) \ No newline at end of file diff --git a/MindIE/MindIE-Torch/built-in/audio/whisperX/modeling_whisper.py b/MindIE/MindIE-Torch/built-in/audio/whisperX/modeling_whisper.py new file mode 100644 index 0000000000000000000000000000000000000000..639c847349b2dc34b08cb9a7ef4755885d2aa3eb --- /dev/null +++ b/MindIE/MindIE-Torch/built-in/audio/whisperX/modeling_whisper.py @@ -0,0 +1,988 @@ +import argparse +import os +import copy +import math +import warnings +import librosa +from typing import Optional, Tuple, Union, List, Dict +from collections import OrderedDict + +import torch +from torch.utils.data import Dataset +import torch.nn.functional as F +import torch.utils.checkpoint +from torch import nn +from transformers import WhisperForConditionalGeneration +from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask +from transformers.models.whisper.modeling_whisper import WhisperDecoder, WhisperEncoder, WhisperEncoderLayer, \ + WhisperModel, WhisperDecoderLayer, WhisperAttention +from transformers.generation.logits_process import LogitsProcessorList +from transformers.generation.stopping_criteria import StoppingCriteriaList, validate_stopping_criteria +from transformers.generation.utils import GenerationMixin +import mindietorch +from mindietorch._enums import dtype + + +class MindiePFA(WhisperAttention): + + def __init__( + self, + embed_dim: int, + num_heads: int, + dropout: float = 0.0, + is_decoder: bool = False, + bias: bool = True, + is_causal: bool = False, + config: bool = None ): + super().__init__( + embed_dim, + num_heads, + dropout, + bias, + is_causal, + config ) + + # BSND + def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): + return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).contiguous() + + def forward( + self, + hidden_states: torch.Tensor, + key_value_states: 
Optional[torch.Tensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + **kwargs ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + # cross_attn + is_cross_attn = key_value_states is not None + bsz, tgt_len, _ = hidden_states.size() + query_states = self.q_proj(hidden_states) + if ( + is_cross_attn + and past_key_value is not None + # past_key_value layout is BSND + and past_key_value[0].shape[1] == key_value_states.shape[1] + ): + # reuse k, v + key_states = past_key_value[0] + value_states = past_key_value[1] + elif is_cross_attn: + key_states = self._shape(self.k_proj(key_value_states), -1, bsz) + value_states = self._shape(self.v_proj(key_value_states), -1, bsz) + # self attn + elif past_key_value is not None: + # reuse k v + key_states = self._shape(self.k_proj(hidden_states), -1, bsz) + value_states = self._shape(self.v_proj(hidden_states), -1, bsz) + key_states = torch.cat([past_key_value[0], key_states], dim=1) + value_states = torch.cat([past_key_value[1], value_states], dim=1) + else: + key_states = self._shape(self.k_proj(hidden_states), -1, bsz) + value_states = self._shape(self.v_proj(hidden_states), -1, bsz) + + if self.is_decoder: + past_key_value = (key_states.to(torch.float16), value_states.to(torch.float16)) + + # B S N D + query_states = query_states.view(bsz, tgt_len, self.num_heads, self.head_dim).contiguous() + + # B S N D + attn_output = torch.ops.aie.flash_attention( + query=query_states, + key=key_states, + value=value_states, + num_head=self.num_heads, + scale=self.scaling, + layout="BSND", + type="PFA" ) + attn_output = attn_output.view(bsz, tgt_len, self.num_heads, self.head_dim) + attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim) + attn_output = self.out_proj(attn_output) + return attn_output, _, past_key_value + +class MindieIFA(WhisperAttention): + def __init__( + self, + embed_dim: int, + num_heads: int, + dropout: float = 0.0, + is_decoder: bool = False, + bias: bool = True, + is_causal: bool = True, + config=None ): + super().__init__( + embed_dim, + num_heads, + dropout, + is_decoder, + bias, + is_causal, + config ) + + # BSND + def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): + return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).contiguous() + + def forward( + self, + hidden_states: torch.Tensor, + past_key_value: torch.Tensor, + actual_seq_len: torch.Tensor, + **kwargs ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + # self attn + assert past_key_value is not None, \ + "Current operation is incre_flash_attention, past_key_value is required." + bsz, tgt_len, _ = hidden_states.size() + assert tgt_len == 1, \ + "Current operation is incre_flash_attention, query's seq length should be equal to 1." 
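+        # Sketch of what the incre_flash_attention call below is assumed to compute per sample
+        # (an assumption, not the operator's documented contract): the single-step query attends
+        # only over the first actual_seq_len cached positions,
+        #     out = softmax(q @ K[:actual_seq_len].T * scale) @ V[:actual_seq_len]
+        # with q of shape (1, num_heads, head_dim) and the K/V caches kept in BSND layout, so
+        # padded cache entries beyond actual_seq_len do not contribute to the result.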
+ query_states = self.q_proj(hidden_states) + key_states = self._shape(self.k_proj(hidden_states), -1, bsz).to(torch.float16) + values_states = self._shape(self.v_proj(hidden_states), -1, bsz).to(torch.float16) + past_key_cache, past_value_cache = past_key_value[0], past_key_value[1] + indices = actual_seq_len - 1 + past_key_cache = torch.ops.aie.scatter_update(past_key_cache, indices, key_states, axis=1) + past_value_cache = torch.ops.aie.scatter_update(past_value_cache, indices, values_states, axis=1) + past_key_value = (past_key_cache, past_value_cache) + # B S N D + query_states = query_states.view(bsz, tgt_len, self.num_heads, self.head_dim).contiguous() + # B S N D + attn_output = torch.ops.aie.incre_flash_attention( + query=query_states, + key=past_key_cache, + value=past_value_cache, + actual_seq_lengths=actual_seq_len, + num_head=self.num_heads, + scale=self.scaling, + layout="BSND") + attn_output = attn_output.view(bsz, tgt_len, self.num_heads, self.head_dim) + attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim) + attn_output = self.out_proj(attn_output) + return attn_output, _, past_key_value + +class MindieFA(WhisperAttention): + + def __init__( + self, + embed_dim: int, + num_heads: int, + dropout: float = 0.0, + is_decoder: bool = False, + bias: bool = True, + is_causal: bool = False, + config: bool = None ): + super().__init__( + embed_dim, + num_heads, + dropout, + bias, + is_causal, + config) + + # BSND + def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): + return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).contiguous() + + def forward( + self, + hidden_states: torch.Tensor, + key_value_states: Optional[torch.Tensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + **kwargs ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + # cross_attn + is_cross_attn = key_value_states is not None + bsz, tgt_len, _ = hidden_states.size() + query_states = self.q_proj(hidden_states) + if ( + is_cross_attn + and past_key_value is not None + # past_key_value layout is BSND + and past_key_value[0].shape[1] == key_value_states.shape[1] + ): + # reuse k, v + key_states = past_key_value[0] + value_states = past_key_value[1] + elif is_cross_attn: + key_states = self._shape(self.k_proj(key_value_states), -1, bsz) + value_states = self._shape(self.v_proj(key_value_states), -1, bsz) + # self attn + elif past_key_value is not None: + # reuse k v + key_states = self._shape(self.k_proj(hidden_states), -1, bsz) + value_states = self._shape(self.v_proj(hidden_states), -1, bsz) + key_states = torch.cat([past_key_value[0], key_states], dim=1) + value_states = torch.cat([past_key_value[1], value_states], dim=1) + else: + key_states = self._shape(self.k_proj(hidden_states), -1, bsz) + value_states = self._shape(self.v_proj(hidden_states), -1, bsz) + + if self.is_decoder: + past_key_value = (key_states.to(torch.float16), value_states.to(torch.float16)) + + # MindFA only support BNSD layout + query_states = query_states.view(bsz, tgt_len, self.num_heads, self.head_dim).contiguous() + attn_output = torch.ops.aie.flash_attention( + query=query_states.transpose(1, 2), + key=key_states.transpose(1, 2), + value=value_states.transpose(1, 2), + num_head=self.num_heads, + scale=self.scaling, + layout="BNSD", + type="FA_HIGH_PERF") + attn_output = attn_output.transpose(1, 2) + attn_output = attn_output.view(bsz, tgt_len, self.num_heads, self.head_dim) + attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim) + attn_output = 
self.out_proj(attn_output) + return attn_output, _, past_key_value + +class MindieAttention(WhisperAttention): + def __init__( + self, + embed_dim: int, + num_heads: int, + dropout: float = 0.0, + is_decoder: bool = False, + bias: bool = True, + is_causal: bool = False, + config=None ): + super().__init__( + embed_dim, + num_heads, + dropout, + is_decoder, + bias, + is_causal, + config ) + + # BSND + def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): + return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).contiguous() + + def forward( + self, + hidden_states: torch.Tensor, + key_value_states: Optional[torch.Tensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + **kwargs ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + bsz, tgt_len, _ = hidden_states.size() + query_states = self.q_proj(hidden_states) + + # self_attention + key_states = self._shape(self.k_proj(hidden_states), -1, bsz) + value_states = self._shape(self.v_proj(hidden_states), -1, bsz) + + # B S N D + query_states = query_states.view(bsz, tgt_len, self.num_heads, self.head_dim).contiguous() + + # B S N D + attn_output = torch.ops.aie.flash_attention( + query=query_states, + key=key_states, + value=value_states, + num_head=self.num_heads, + scale=self.scaling, + layout="BSND", + type="PFA") + + attn_output = attn_output.view(bsz, tgt_len, self.num_heads, self.head_dim) + attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim) + attn_output = self.out_proj(attn_output) + return attn_output, _, _ + + +class MindieWhisperDecoderLayer(WhisperDecoderLayer): + + def __init__(self, config): + super().__init__(config) + self.embed_dim = config.d_model + if config.is_incre_decode: + self.self_attn = MindieIFA( + embed_dim=self.embed_dim, + num_heads=config.decoder_attention_heads, + config=config, + is_decoder=True + ) + self.encoder_attn = MindieIFA( + embed_dim=self.embed_dim, + num_heads=config.decoder_attention_heads, + is_decoder=True, + config=config + ) + elif config.machine_type == "800IA2": + self.self_attn = MindiePFA( + embed_dim=self.embed_dim, + num_heads=config.decoder_attention_heads, + config=config, + is_decoder=True + ) + self.encoder_attn = MindiePFA( + embed_dim=self.embed_dim, + num_heads=config.decoder_attention_heads, + is_decoder=True, + config=config + ) + elif config.machine_type == "300IPro": + self.self_attn = MindieFA( + embed_dim=self.embed_dim, + num_heads=config.decoder_attention_heads, + config=config, + is_decoder=True + ) + self.encoder_attn = MindieFA( + embed_dim=self.embed_dim, + num_heads=config.decoder_attention_heads, + is_decoder=True, + config=config + ) + else: + raise ValueError(f"Unsupporting current parameters. eg. 
" + f"is_incre_decode: {is_incre_decode} machine_type {machine_type}") + + def forward( + self, + hidden_states, + attention_mask, + encoder_hidden_states, + layer_head_mask, + cross_attn_layer_head_mask, + past_key_value, + cross_attn_past_key_value, + output_attentions, + use_cache, + actual_seq_len=None, + encoder_attention_mask=None + ) -> torch.Tensor: + residual = hidden_states + hidden_states = self.self_attn_layer_norm(hidden_states) + + # Self Attention + self_attn_past_key_value = past_key_value if past_key_value is not None else None + hidden_states, _, present_key_value = self.self_attn( + hidden_states=hidden_states, + past_key_value=self_attn_past_key_value, + attention_mask=attention_mask, + layer_head_mask=layer_head_mask, + output_attentions=output_attentions, + actual_seq_len=actual_seq_len + ) + hidden_states = residual + hidden_states + + # Cross-Attention Block + cross_attn_present_key_value = None + if encoder_hidden_states is not None: + residual = hidden_states + hidden_states = self.encoder_attn_layer_norm(hidden_states) + + hidden_states, _, cross_attn_present_key_value = self.encoder_attn( + hidden_states=hidden_states, + key_value_states=encoder_hidden_states, + attention_mask=encoder_attention_mask, + layer_head_mask=cross_attn_layer_head_mask, + past_key_value=cross_attn_past_key_value, + output_attentions=output_attentions, + ) + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + hidden_states = residual + hidden_states + + present_key_value = present_key_value + cross_attn_present_key_value + + # Fully Connected + residual = hidden_states + hidden_states = self.final_layer_norm(hidden_states) + hidden_states = hidden_states.reshape(-1, self.embed_dim) + hidden_states = self.activation_fn(self.fc1(hidden_states)) + hidden_states = self.fc2(hidden_states) + hidden_states = hidden_states.reshape(-1, 1, self.embed_dim) + hidden_states = residual + hidden_states + outputs = (hidden_states,) + + if use_cache: + outputs += (present_key_value,) + + return outputs + + +class MindieWhisperDecoder(WhisperDecoder): + + def __init__(self, config): + super().__init__(config) + self.layers = nn.ModuleList([MindieWhisperDecoderLayer(config) + for _ in range(config.decoder_layers)]) + self.config = config + + def forward( + self, + input_ids, + encoder_hidden_states, + past_key_values, + actual_seq_len, + use_cache=True, + attention_mask=None): + if input_ids is None: + raise ValueError("You have to specify either decoder_input_ids") + inputs_embeds = self.embed_tokens(input_ids) + + input_shape = input_ids.size() + input_ids = input_ids.view(-1, input_shape[-1]) + # B S N D + past_key_values_length = past_key_values[0].shape[1] if past_key_values is not None else 0 + inputs_embeds = self.embed_tokens(input_ids) + # embed positions + positions = self.embed_positions(input_ids, past_key_values_length=past_key_values_length) + hidden_states = inputs_embeds + positions + + past_key_value_cache = [] + for idx, decoder_layer in enumerate(self.layers): + past_key_value = (past_key_values[4 * idx], past_key_values[4 * idx + 1]) \ + if past_key_values is not None else None + cross_past_key_value = (past_key_values[4 * idx + 2], past_key_values[4 * idx + 3]) \ + if past_key_values is not None else None + + layer_outputs = decoder_layer( + hidden_states, + actual_seq_len=actual_seq_len, + attention_mask=attention_mask, + encoder_hidden_states=encoder_hidden_states, + layer_head_mask=None, + cross_attn_layer_head_mask=None, + 
past_key_value=past_key_value, + cross_attn_past_key_value=cross_past_key_value, + output_attentions=None, + use_cache=use_cache) + + hidden_states = layer_outputs[0] + past_key_value_cache.extend(layer_outputs[1]) + + hidden_states = self.layer_norm(hidden_states) + return hidden_states, past_key_value_cache + + +class MindieWhisperEncoderLayer(WhisperEncoderLayer): + def __init__(self, config): + super().__init__(config) + self.self_attn = MindiePFA( + embed_dim=self.embed_dim, + num_heads=config.encoder_attention_heads, + dropout=config.attention_dropout, + config=config, + ) + + +class MindieWhisperEncoder(WhisperEncoder): + + def __init__(self, config): + super().__init__(config) + self.layers = nn.ModuleList([MindieWhisperEncoderLayer(config) for _ in range(config.encoder_layers)]) + + +class MindieWhisperModel(WhisperModel): + + def __init__(self, config): + super().__init__(config) + self.decoder = MindieWhisperDecoder(config) + self.encoder = MindieWhisperEncoder(config) + + def forward( + self, + encoder_outputs, + decoder_input_ids, + past_key_values: Optional[torch.Tensor] = None, + actual_seq_len: Optional[torch.Tensor] = None, + use_cache: Optional[bool] = False, + return_dict: Optional[bool] = True, + input_features: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None + ) -> List[torch.Tensor]: + if input_features is None and encoder_outputs is None: + raise ValueError("Make sure to provide either `input_features` or `encoder_outputs` to `forward`.") + + if encoder_outputs is None: + input_features = self._mask_input_features(input_features, attention_mask=attention_mask) + + encoder_outputs = self.encoder( + input_features, + head_mask=None, + output_attentions=None, + output_hidden_states=None, + return_dict=return_dict, + ) + + decoder_outputs = self.decoder( + input_ids=decoder_input_ids, + encoder_hidden_states=encoder_outputs, + past_key_values=past_key_values, + use_cache=use_cache, + actual_seq_len=actual_seq_len + ) + return decoder_outputs + + +class MindieWhisperForConditionalGeneration(WhisperForConditionalGeneration, GenerationMixin): + + def __init__(self, config, is_incre_decode=False, machine_type="800A2"): + super().__init__(config) + config.is_incre_decode = is_incre_decode + config.machine_type = machine_type + self.model = MindieWhisperModel(config) + self.has_load = False + self.has_compile = False + self.mindie_encoder = None + self.mindie_decoder_prefill = None + self.mindie_decoder = None + self.self_attn_scatter = None + self.encoder_attn_scatter = None + self.save_path = None + self.encoder_seq_len = 1500 + self.file_prefix_names = ["mindie_whisper_encoder_bs", + "mindie_decoder_prefill_bs", + "mindie_whisper_decoder_bs", + "mindie_self_scatter_bs", + "mindie_encoder_scatter_bs" + ] + self.hidden_size = 1280 + self.head_num = 20 + self.head_size = 64 + self.seq_len = 1 + self.layer_nums = 32 + self.max_len = 448 + self.past_key_value = [] + + def forward(self, *args): + if len(args) not in (2, 131): + raise ValueError(f"The args length of forward can only be 2 or 131, but got {len(args)}") + decoder_input_ids = args[0] + encoder_outputs = args[1] + if len(args) == 131: + actual_seq_len = args[2] + past_key_values = args[3:] + else: + actual_seq_len = None + past_key_values = None + outputs = self.model( + decoder_input_ids=decoder_input_ids, + encoder_outputs=encoder_outputs, + past_key_values=past_key_values, + actual_seq_len=actual_seq_len, + use_cache=True, + return_dict=False, + input_features=None + ) + 
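+        # Why 131: the incremental decoder takes decoder_input_ids, encoder_outputs and
+        # actual_seq_len (3 tensors) plus a flat list of 4 caches per layer for 32 layers
+        # (self-attn K, self-attn V, cross-attn K, cross-attn V), i.e. 3 + 4 * 32 = 131.
+        # After the leading three arguments are stripped, layer idx therefore owns
+        # past_key_values[4 * idx : 4 * idx + 4]; the prefill path passes only the first two.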
lm_logits = self.proj_out(outputs[0]) + return [lm_logits] + outputs[1] + + def compile_encoder(self, args): + encoder = self.get_encoder() + class Encoder(torch.nn.Module): + + def __init__(self, model): + super().__init__() + self.model = model + + def forward(self, input_features): + return self.model(input_features=input_features, return_dict=False) + + mel_feature_size = 128 + max_frames = 3000 + input_features = torch.randn([args.bs, mel_feature_size, max_frames]) + encoder_traced = torch.jit.trace(Encoder(encoder), (input_features)) + input_info = [mindietorch.Input(shape=(args.bs, mel_feature_size, max_frames))] + print("Start AOE optimization_level 1.") + compiled = mindietorch.compile(encoder_traced, + inputs=input_info, + precision_policy=mindietorch.PrecisionPolicy.FP16, + soc_version=args.soc_version, + optimization_level=1 + ) + print("AOE optimize finished and start compile encoder.") + compiled = mindietorch.compile(encoder_traced, + inputs=input_info, + precision_policy=mindietorch.PrecisionPolicy.FP16, + soc_version=args.soc_version, + ) + save_file = os.path.join(args.save_path, f"{self.file_prefix_names[0]}{args.bs}.ts") + torch.jit.save(compiled, save_file) + print(f"Compile encoder success, saved in {save_file}") + + def compile_prefill_decoder(self, args): + print("Start compiling prefill_decoder.") + encoder_seq_len = 1500 + hidden_size = 1280 + encoder_outputs = torch.randn([args.bs, encoder_seq_len, hidden_size]) + decoder_input_ids = torch.randint(1, 4, (args.bs, 1)) + prefill_decoder_traced = torch.jit.trace(self.eval(), (decoder_input_ids, encoder_outputs)) + input_info = [mindietorch.Input(shape=(args.bs, 1), dtype=dtype.INT64), + mindietorch.Input(shape=(args.bs, encoder_seq_len, hidden_size))] + prefill_decoder_compiled = mindietorch.compile(prefill_decoder_traced, + inputs=input_info, + precision_policy=mindietorch.PrecisionPolicy.FP16, + soc_version=args.soc_version) + + save_file = os.path.join(args.save_path, f"{self.file_prefix_names[1]}{args.bs}.ts") + torch.jit.save(prefill_decoder_compiled, save_file) + print(f"Compile prefill_decoder success, saved in {save_file}.") + + def compile_decoder(self, args): + + class Decoder(torch.nn.Module): + + def __init__(self, model): + super().__init__() + self.model = model + + def forward(self, *args): + return self.model.forward(*args)[0] + + mindie_whisper = MindieWhisperForConditionalGeneration.from_pretrained(args.model_path, + is_incre_decode=True) + decoder = Decoder(mindie_whisper) + print("Start compiling decoder.") + + encoder_outputs = torch.randn([args.bs, self.encoder_seq_len, self.hidden_size]) + decoder_input_ids = torch.randint(1, 4, (args.bs, 1)) + actual_seq_len = torch.ones((args.bs)) + all_past_key_value = [torch.randn([args.bs, self.seq_len, self.head_num, self.head_size]), + torch.randn([args.bs, self.seq_len, self.head_num, self.head_size]), + torch.randn([args.bs, self.encoder_seq_len, self.head_num, self.head_size]), + torch.randn([args.bs, self.encoder_seq_len, self.head_num, self.head_size]) + ] * self.layer_nums + traced_args = [decoder_input_ids, encoder_outputs, actual_seq_len] + traced_args.extend(all_past_key_value) + traced_decoder = torch.jit.trace(decoder, traced_args) + # BSND + key_value_infos = [ + mindietorch.Input(shape=(args.bs, self.max_len, self.head_num, self.head_size), + dtype=dtype.FLOAT16), + mindietorch.Input(shape=(args.bs, self.max_len, self.head_num, self.head_size), + dtype=dtype.FLOAT16), + mindietorch.Input(shape=(args.bs, self.encoder_seq_len, self.head_num, 
self.head_size), + dtype=dtype.FLOAT16), + mindietorch.Input(shape=(args.bs, self.encoder_seq_len, self.head_num, self.head_size), + dtype=dtype.FLOAT16 + )] * self.layer_nums + input_info = [mindietorch.Input(shape=(args.bs, 1), dtype=dtype.INT64), + mindietorch.Input(shape=(args.bs, self.encoder_seq_len, self.hidden_size)), + mindietorch.Input(shape=(args.bs, ), dtype=dtype.INT64)] + + input_info.extend(key_value_infos) + float_size = 4 + voc_size = 51866 + buffer_size = math.ceil((args.bs * 1 * voc_size * float_size) / 1024 / 1024) + print(f"Set {buffer_size}/MB for output.") + compiled_decoder = mindietorch.compile(traced_decoder, + inputs=input_info, + precision_policy=mindietorch.PrecisionPolicy.FP16, + soc_version=args.soc_version, + default_buffer_size_vec=[buffer_size]) + save_file = os.path.join(args.save_path, f"{self.file_prefix_names[2]}{args.bs}.ts") + torch.jit.save(compiled_decoder, save_file) + + print(f"Compile whisper_decoder success, saved in {save_file}.") + + def compile_scatter_update(self, args): + + class MindieScatter(torch.nn.Module): + def forward(self, past_key_value, indices, update_states): + out = torch.ops.aie.scatter_update(past_key_value, indices, update_states, 1) + return out + bs = args.bs + self_past_key_value = torch.randn([bs, self.max_len, self.head_num, self.head_size]) + encoder_past_key_value = torch.randn([bs, self.max_len, self.head_num, self.head_size]) + indices = torch.tensor([0]*bs) + update_states = torch.randn([bs, 1, 20, 64]) + traced = torch.jit.trace(MindieScatter(), (self_past_key_value, indices, update_states)) + + self_attn_info = mindietorch.Input(shape=self_past_key_value.shape, dtype=mindietorch.dtype.FLOAT16) + encoder_attn_info = mindietorch.Input(shape=encoder_past_key_value.shape, dtype=mindietorch.dtype.FLOAT16) + indices_info = mindietorch.Input(shape=indices.shape, dtype=mindietorch.dtype.INT64) + update_states_info = mindietorch.Input(shape=encoder_past_key_value.shape, dtype=mindietorch.dtype.FLOAT16) + + compile_self = mindietorch.compile(traced, inputs=[self_attn_info, indices_info, update_states_info], + soc_version=args.soc_version) + torch.jit.save(compile_self, f"{args.save_path}/{self.file_prefix_names[3]}{args.bs}.ts") + + compile_self = mindietorch.compile(traced, inputs=[encoder_attn_info, indices_info, update_states_info], + soc_version=args.soc_version) + torch.jit.save(compile_self, f"{args.save_path}/{self.file_prefix_names[4]}{args.bs}.ts") + print("compile scatter success.") + + def compile(self, args): + print("Start compiling Mindie-Whisper, it will take some time, please wait.") + save_path = args.save_path + + if not save_path: + raise ValueError("Please provide the directory where the compiled model saved.") + if not os.path.exists(save_path): + os.makedirs(save_path) + print(f"Directory {save_path} created.") + else: + print(f"Directory {save_path} already exists.") + self.compile_encoder(args) + self.compile_prefill_decoder(args) + self.compile_decoder(args) + self.compile_scatter_update(args) + self.has_compile = True + + def _init_past_key_value_cache(self, bsz): + for _ in range(32): + self.past_key_value.append(torch.ones([bsz, self.max_len, self.head_num, self.head_size])) + self.past_key_value.append(torch.ones([bsz, self.max_len, self.head_num, self.head_size])) + self.past_key_value.append(torch.ones([bsz, self.encoder_seq_len, self.head_num, self.head_size])) + self.past_key_value.append(torch.ones([bsz, self.encoder_seq_len, self.head_num, self.head_size])) + print("init past key value 
cache success.") + + def load_mindie_models(self, save_path, batch_size): + if not (save_path and batch_size): + raise ValueError(f"Please provide batch_size and the directory where the compiled models saved,\ + but found save_path is {save_path}, batch_size is{batch_size}.") + self._check_save_path(save_path, batch_size) + self.batch_size = batch_size + + self._init_past_key_value_cache(batch_size) + + + if not self.has_load: + self.mindie_encoder = torch.jit.load(f"{save_path}/{self.file_prefix_names[0]}{batch_size}.ts") + print(f"load {self.file_prefix_names[0]}{batch_size}.ts success.") + + self.mindie_decoder_prefill = torch.jit.load(f"{save_path}/{self.file_prefix_names[1]}{batch_size}.ts") + print(f"load {self.file_prefix_names[1]}{batch_size}.ts success.") + + self.mindie_decoder = torch.jit.load(f"{save_path}/{self.file_prefix_names[2]}{batch_size}.ts") + print(f"load {self.file_prefix_names[2]}{batch_size}.ts success.") + + self.self_attn_scatter = torch.jit.load(f"{save_path}/{self.file_prefix_names[3]}{batch_size}.ts") + print(f"load {self.file_prefix_names[3]}{batch_size}.ts success.") + + self.self_attn_scatter = torch.jit.load(f"{save_path}/{self.file_prefix_names[4]}{batch_size}.ts") + print(f"load {self.file_prefix_names[4]}{batch_size}.ts success.") + self.has_load = True + else: + print("Mindie whisper has already load.") + + def greedy_search( + self, + input_ids: torch.LongTensor, + logits_processor: Optional[LogitsProcessorList] = None, + stopping_criteria: Optional[StoppingCriteriaList] = None, + max_length: Optional[int] = None, + pad_token_id: Optional[int] = None, + eos_token_id: Optional[Union[int, List[int]]] = None, + output_scores: Optional[bool] = None, + return_dict_in_generate: Optional[bool] = None, + synced_gpus: bool = False, + **model_kwargs): + logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList() + stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList() + if max_length is not None: + print( + "`max_length` is deprecated in this function, use" + " `stopping_criteria=StoppingCriteriaList([MaxLengthCriteria(max_length=max_length)])` instead." 
+ ) + stopping_criteria = validate_stopping_criteria(stopping_criteria, max_length) + pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id + eos_token_id = eos_token_id if eos_token_id is not None else self.generation_config.eos_token_id + if isinstance(eos_token_id, int): + eos_token_id = [eos_token_id] + eos_token_id_tensor = torch.tensor(eos_token_id).to("cpu") if eos_token_id is not None else None + output_scores = output_scores if output_scores is not None else self.generation_config.output_scores + # init attention / hidden states / scores tuples + scores = () if (return_dict_in_generate and output_scores) else None + + # keep track of which sequences are already finished + unfinished_sequences = torch.ones(input_ids.shape[0], dtype=torch.long, device="cpu") + + this_peer_finished = False # used by synced_gpus only + kv_actual_step = 1 + indices = torch.tensor([0] * input_ids.shape[0]) + is_first_step = True + while True: + + model_inputs = self.prepare_inputs_for_generation(input_ids, is_first_step, **model_kwargs) + args = [model_inputs["decoder_input_ids"].contiguous().to("npu"), model_inputs["encoder_outputs"]] + if is_first_step: + outputs = self.mindie_decoder_prefill(*args) + for idx in range(32): + self.self_attn_scatter(self.past_key_value[0], indices, outputs[1 + 4*idx]) + self.self_attn_scatter(self.past_key_value[1], indices, outputs[1 + 4*idx + 1]) + self.encoder_attn_scatter(self.past_key_value[2], indices, outputs[1 + 4*idx + 2]) + self.encoder_attn_scatter(self.past_key_value[3], indices, outputs[1 + 4*idx + 3]) + is_first_step = False + else: + kv_actual_step += 1 + args.append(torch.tensor([kv_actual_step] * input_ids.shape[0])).to("npu") + args.extend(self.past_key_value) + outputs = self.mindie_decoder(*args) + if isinstance(outputs, list): + next_token_logits = outputs[0].to("cpu")[:, -1, :] + else: + next_token_logits = outputs.to("cpu")[:, -1, :] + + # pre-process distribution + next_tokens_scores = logits_processor(input_ids.to("cpu"), next_token_logits) + + # argmax + next_tokens = torch.argmax(next_tokens_scores, dim=-1) + + # finished sentences should have their next token be a padding token + if eos_token_id is not None: + if pad_token_id is None: + raise ValueError("If `eos_token_id` is defined, make sure that `pad_token_id` is defined.") + next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences) + + # update generated ids, model inputs, and length for next step + input_ids = torch.cat([input_ids.to("cpu"), next_tokens[:, None]], dim=-1) + + model_kwargs = self._update_model_kwargs_for_generation( + outputs, model_kwargs + ) + + # if eos_token was found in one sentence, set sentence to finished + if eos_token_id_tensor is not None: + unfinished_sequences = unfinished_sequences.mul( + next_tokens.tile(eos_token_id_tensor.shape[0], 1).ne(eos_token_id_tensor.unsqueeze(1)).prod(dim=0) + ) + + # stop when each sentence is finished + if unfinished_sequences.max() == 0: + this_peer_finished = True + + # stop if we exceed the maximum length + if stopping_criteria(input_ids, scores): + this_peer_finished = True + + if this_peer_finished and not synced_gpus: + break + + return input_ids + + def generate( + self, + input_features: Optional[torch.Tensor] = None, + generation_config=None, + logits_processor=None, + stopping_criteria=None, + prefix_allowed_tokens_fn=None, + synced_gpus=False, + return_dict_in_generate: Optional[bool] = None, + **kwargs, + ): + if generation_config is None: 
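+            # generate() only handles <=30 s inputs: the features are encoded once on the NPU,
+            # the encoder output is passed along through model_kwargs, and GenerationMixin.generate
+            # then drives greedy_search above (the prefill decoder on the first step, the compiled
+            # incremental decoder plus scatter-updated KV caches on every later step).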
+ generation_config = copy.deepcopy(self.generation_config) + num_segment_frames = 3000 + assert input_features.shape[-1] == num_segment_frames, "Only support 30s speech." + encoder_outputs = self.mindie_encoder(input_features.to("npu"))[0] + kwargs["encoder_outputs"] = encoder_outputs + outputs = GenerationMixin.generate( + self, + input_features, + generation_config, + logits_processor, + stopping_criteria, + prefix_allowed_tokens_fn, + synced_gpus, + return_dict_in_generate=return_dict_in_generate, + **kwargs + ) + return outputs + + def _update_model_kwargs_for_generation( + self, + outputs, + model_inputs, + ): + return model_inputs + + def _maybe_initialize_input_ids_for_generation( + self, + inputs: Optional[torch.Tensor] = None, + bos_token_id: Optional[int] = None, + model_kwargs: Optional[Dict[str, torch.Tensor]] = None, + ) -> torch.LongTensor: + """Initializes input ids for generation, if necessary.""" + if inputs is not None: + return inputs + + encoder_outputs = model_kwargs.get("encoder_outputs") + if self.config.is_encoder_decoder and encoder_outputs is not None: + # make dummy input_ids with value -100, as a sanity check ensuring that they won't be used for encoding + shape = encoder_outputs.size()[:-1] + input_ids = torch.ones(shape, dtype=torch.long, device="cpu") * -100 + return input_ids.to("npu") + + if bos_token_id is None: + raise ValueError("`bos_token_id` has to be defined when no `input_ids` are provided.") + + # If there is some tensor in `model_kwargs`, we can infer the batch size from it. This is helpful with + # soft-prompting or in multimodal implementations built on top of decoder-only language models. + batch_size = 1 + for value in model_kwargs.values(): + if isinstance(value, torch.Tensor): + batch_size = value.shape[0] + break + return torch.ones((batch_size, 1), dtype=torch.long, device=self.device) * bos_token_id + + def _prepare_decoder_input_ids_for_generation( + self, + batch_size: int, + model_input_name: str, + model_kwargs: Dict[str, torch.Tensor], + decoder_start_token_id: int = None, + bos_token_id: int = None, + device: torch.device = None, + ) -> Tuple[torch.LongTensor, Dict[str, torch.Tensor]]: + """Prepares `decoder_input_ids` for generation with encoder-decoder models""" + # 1. Check whether the user has defined `decoder_input_ids` manually. To facilitate in terms of input naming, + # we also allow the user to pass it under `input_ids`, if the encoder does not use it as the main input. + if model_kwargs is not None and "decoder_input_ids" in model_kwargs: + decoder_input_ids = model_kwargs.pop("decoder_input_ids") + elif "input_ids" in model_kwargs and model_input_name != "input_ids": + decoder_input_ids = model_kwargs.pop("input_ids") + else: + decoder_input_ids = None + + # 2. Encoder-decoder models expect the `decoder_input_ids` to start with a special token. Let's ensure that. 
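+        # For Whisper the start token comes from generation_config.decoder_start_token_id
+        # (the <|startoftranscript|> token); any user-supplied prompt that does not already
+        # begin with it gets the token prepended below, and the result is moved to the NPU
+        # before the prefill step.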
+ decoder_start_token_id = self._get_decoder_start_token_id(decoder_start_token_id, bos_token_id) + if device is None: + device = self.device + decoder_input_ids_start = torch.ones((batch_size, 1), dtype=torch.long, device="cpu") * decoder_start_token_id + + # no user input -> use decoder_start_token_id as decoder_input_ids + if decoder_input_ids is None: + decoder_input_ids = decoder_input_ids_start + # exception: Donut checkpoints have task-specific decoder starts and don't expect a BOS token + elif self.config.model_type == "vision-encoder-decoder" and "donut" in self.name_or_path.lower(): + pass + # user input but doesn't start with decoder_start_token_id -> prepend decoder_start_token_id (and adjust + # decoder_attention_mask if provided) + elif (decoder_input_ids[:, 0] != decoder_start_token_id).all().item(): + decoder_input_ids = torch.cat([decoder_input_ids_start, decoder_input_ids], dim=-1) + if "decoder_attention_mask" in model_kwargs: + decoder_attention_mask = model_kwargs["decoder_attention_mask"] + decoder_attention_mask = torch.cat( + (torch.ones_like(decoder_attention_mask)[:, :1], decoder_attention_mask), + dim=-1, + ) + model_kwargs["decoder_attention_mask"] = decoder_attention_mask.to(self.device) + + return decoder_input_ids.to("npu"), model_kwargs + + def _check_save_path(self, save_path, batch_size): + file_list = os.listdir(save_path) + expected_files = [file + f"{batch_size}.ts" for file in self.file_prefix_names] + for file in expected_files: + if file not in file_list: + raise ValueError(f"Expected file name is {file}, but can't be found in path: {save_path}") + + def prepare_inputs_for_generation( + self, + decoder_input_ids, + if_first_step, + encoder_outputs=None, + **kwargs + ): + if not if_first_step: + decoder_input_ids_shape = decoder_input_ids.shape + remove_prefix_length = decoder_input_ids_shape[1] - 1 + decoder_input_ids = decoder_input_ids[:, remove_prefix_length] + return { + "encoder_outputs": encoder_outputs, + "decoder_input_ids": decoder_input_ids + } \ No newline at end of file diff --git a/MindIE/MindIE-Torch/built-in/audio/whisperX/readme.md b/MindIE/MindIE-Torch/built-in/audio/whisperX/readme.md new file mode 100644 index 0000000000000000000000000000000000000000..c9f020d143c74ee7ff18453c53925232e8dfd412 --- /dev/null +++ b/MindIE/MindIE-Torch/built-in/audio/whisperX/readme.md @@ -0,0 +1,80 @@ +# Whisper-large-v3模型-推理指导 + +- [概述](#概述) +- [推理环境准备](#推理环境准备) +- [快速上手](#快速上手) + - [获取源码](#获取源码) + - [模型推理](#模型推理) + +# 概述 + +该工程使用mindietorch部署whisper-large-v3模型 + + +# 推理环境准备 + +- 该模型需要以下插件与驱动 + + **表 1** 版本配套表 + + | 配套 | 版本 | 环境准备指导 | + |-----------| ------- | ------------ | + | Python | 3.10.13 | - | + | torch | 2.1.0+cpu | - | + | torch_audio | 2.1.0+cpu | - | + | CANN | 8.0.RC3 | - | + | MindIE | 1.0.RC3 | - | + +# 快速上手 +## 获取源码 + +1. 安装mindie包 + + ```bash + # 安装mindie + chmod +x ./Ascend-mindie_xxx.run + ./Ascend-mindie_xxx.run --install + source /usr/local/Ascend/mindie/set_env.sh + ``` + + +2. 模型权重下载路径: + ```bash + https://huggingface.co/openai/whisper-large-v3/tree/main + ``` + 将权重文件存放至当前目录下的model_path文件夹,请先创建改文件夹。 + + +5. 安装依赖 + ``` + pip3 install torch==2.1.0 --index-url https://download.pytorch.org/whl/cpu + pip3 install nltk + pip3 install librosa + pip3 install transformers==4.36.0 + pip3 install numpy==1.26.0 + ``` + +## 模型推理 +1. 设置mindie内存池上限为32,执行如下命令设置环境变量。内存池设置过小,内存重复申请和释放会影响性能。 + ``` + export TORCH_AIE_NPU_CACHE_MAX_SIZE=32 + ``` + +2. 
模型编译和推理 + ``` + python3 compile_whisper.py \ + -model_path ./model_path \ + -bs 16 \ + -save_path ./compiled_models \ + -soc_version Ascend310P3 + ``` + + 参数说明: + - -model_path:预训练模型路径,必选。 + - -bs:batch_size, 默认值为8, 可选。 + - -save_path: 编译好的模型的保存文件,必选。 + - -device_id: 选在模型运行的卡编号,默认值0,可选。 + - -soc_version: 芯片类型,默认值:Ascend310P3,仅支持配置Ascend310P3或者Ascend910B4。 + + 约束说明: + 1. 当前暂不支持动态batch,batch_size改变后,需要重新编图。 diff --git a/MindIE/MindIE-Torch/built-in/audio/whisper_large_v3/config.py b/MindIE/MindIE-Torch/built-in/audio/whisper_large_v3/config.py new file mode 100644 index 0000000000000000000000000000000000000000..08d322455aaed00b086e749e59913e917e78c3cb --- /dev/null +++ b/MindIE/MindIE-Torch/built-in/audio/whisper_large_v3/config.py @@ -0,0 +1,19 @@ + +from dataclasses import dataclass, field +from typing import List + +@dataclass +class CompileInfo: + prefix_name = ["mindie_whisper_encoder_bs", + "mindie_decoder_prefill_bs", + "mindie_whisper_decoder_bs", + "mindie_self_scatter_bs", + "mindie_encoder_scatter_bs"] + mel_feature_size = 128 + max_frames = 3000 + max_decode_step = 448 + head_num = 20 + head_size = 64 + encoder_seq_len = 1500 + hidden_size = 1280 + layer_nums = 32 diff --git a/MindIE/MindIE-Torch/built-in/audio/whisper_large_v3/mindie_whisper.py b/MindIE/MindIE-Torch/built-in/audio/whisper_large_v3/mindie_whisper.py new file mode 100644 index 0000000000000000000000000000000000000000..4694e441913f6a7a859f3aff6007bd98501659a5 --- /dev/null +++ b/MindIE/MindIE-Torch/built-in/audio/whisper_large_v3/mindie_whisper.py @@ -0,0 +1,772 @@ +import argparse +import os +import copy +import math +import warnings +import librosa +from typing import Optional, Tuple, Union, List, Dict +from collections import OrderedDict + +import torch +from torch.utils.data import Dataset +import torch.nn.functional as F +import torch.utils.checkpoint +from torch import nn +from transformers import WhisperForConditionalGeneration +from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask +from transformers.models.whisper.modeling_whisper import WhisperDecoder, WhisperEncoder,WhisperEncoderLayer,\ +WhisperModel,WhisperDecoderLayer,WhisperAttention +from transformers.generation.logits_process import LogitsProcessorList +from transformers.generation.stopping_criteria import StoppingCriteriaList, validate_stopping_criteria +from transformers.generation.utils import GenerationMixin +import mindietorch +from mindietorch._enums import dtype + + +class MindieAttention(WhisperAttention): + def __init__( + self, + embed_dim: int, + num_heads: int, + dropout: float = 0.0, + is_decoder: bool = False, + bias: bool = True, + is_causal: bool = False, + config = None + ): + super().__init__( + embed_dim, + num_heads, + dropout, + is_decoder, + bias, + is_causal, + config + ) + + # BSND + def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): + return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).contiguous() + + def forward( + self, + hidden_states: torch.Tensor, + key_value_states: Optional[torch.Tensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + **kwargs + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + + bsz, tgt_len, _ = hidden_states.size() + query_states = self.q_proj(hidden_states) + + # self_attention + key_states = self._shape(self.k_proj(hidden_states), -1, bsz) + value_states = self._shape(self.v_proj(hidden_states), -1, bsz) + + # B S N D + query_states = query_states.view(bsz, tgt_len, self.num_heads, 
self.head_dim).contiguous() + + # B S N D + attn_output = torch.ops.aie.flash_attention( + query = query_states, + key = key_states, + value = value_states, + num_head = self.num_heads, + scale = self.scaling, + layout = "BSND", + type = "PFA" + ) + + attn_output = attn_output.view(bsz, tgt_len, self.num_heads, self.head_dim) + attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim) + attn_output = self.out_proj(attn_output) + return attn_output, _, _ + + +class MindieWhisperDecoderLayer(WhisperDecoderLayer): + + def __init__(self, config): + super().__init__(config) + self.embed_dim = config.d_model + + def forward( + self, + hidden_states, + attention_mask, + encoder_hidden_states, + layer_head_mask, + cross_attn_layer_head_mask, + past_key_value, + cross_attn_past_key_value, + output_attentions, + use_cache, + encoder_attention_mask=None + ) -> torch.Tensor: + residual = hidden_states + hidden_states = self.self_attn_layer_norm(hidden_states) + + # Self Attention + self_attn_past_key_value = past_key_value if past_key_value is not None else None + hidden_states, _, present_key_value = self.self_attn( + hidden_states=hidden_states, + past_key_value=self_attn_past_key_value, + attention_mask=attention_mask, + layer_head_mask=layer_head_mask, + output_attentions=output_attentions, + ) + hidden_states = residual + hidden_states + + # Cross-Attention Block + cross_attn_present_key_value = None + if encoder_hidden_states is not None: + residual = hidden_states + hidden_states = self.encoder_attn_layer_norm(hidden_states) + + hidden_states, _, cross_attn_present_key_value = self.encoder_attn( + hidden_states=hidden_states, + key_value_states=encoder_hidden_states, + attention_mask=encoder_attention_mask, + layer_head_mask=cross_attn_layer_head_mask, + past_key_value=cross_attn_past_key_value, + output_attentions=output_attentions, + ) + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + hidden_states = residual + hidden_states + + # add cross-attn to positions 3,4 of present_key_value tuple + present_key_value = present_key_value + cross_attn_present_key_value + + # Fully Connected + residual = hidden_states + hidden_states = self.final_layer_norm(hidden_states) + hidden_states = hidden_states.reshape(-1, self.embed_dim) + hidden_states = self.activation_fn(self.fc1(hidden_states)) + hidden_states = self.fc2(hidden_states) + hidden_states = hidden_states.reshape(-1, 1, self.embed_dim) + hidden_states = residual + hidden_states + + outputs = (hidden_states,) + + if use_cache: + outputs += (present_key_value,) + + return outputs + + +class MindieWhisperDecoder(WhisperDecoder): + + def __init__(self, config): + super().__init__(config) + self.layers = nn.ModuleList([MindieWhisperDecoderLayer(config) for _ in range(config.decoder_layers)]) + self.config = config + + def forward( + self, + input_ids, + encoder_hidden_states, + past_key_values, + use_cache=True, + attention_mask=None): + if input_ids is None: + raise ValueError("You have to specify either decoder_input_ids") + inputs_embeds = self.embed_tokens(input_ids) + + input_shape = input_ids.size() + input_ids = input_ids.view(-1, input_shape[-1]) + past_key_values_length = past_key_values[0].shape[2] if past_key_values is not None else 0 + + attention_mask = _prepare_4d_causal_attention_mask( + attention_mask, input_shape, inputs_embeds, past_key_values_length + ) + inputs_embeds = self.embed_tokens(input_ids) + # embed positions + positions = self.embed_positions(input_ids, 
past_key_values_length=past_key_values_length) + hidden_states = inputs_embeds + positions + + past_key_value_cache = [] + for idx, decoder_layer in enumerate(self.layers): + + past_key_value = (past_key_values[4*idx], past_key_values[4*idx + 1])\ + if past_key_values is not None else None + cross_past_key_value = (past_key_values[4*idx + 2], past_key_values[4*idx + 3])\ + if past_key_values is not None else None + + layer_outputs = decoder_layer( + hidden_states, + attention_mask=attention_mask, + encoder_hidden_states=encoder_hidden_states, + layer_head_mask=None, + cross_attn_layer_head_mask=None, + past_key_value=past_key_value, + cross_attn_past_key_value=cross_past_key_value, + output_attentions=None, + use_cache=use_cache) + + hidden_states = layer_outputs[0] + past_key_value_cache.extend(layer_outputs[1]) + + hidden_states = self.layer_norm(hidden_states) + return hidden_states, past_key_value_cache + +class MindieWhisperEncoderLayer(WhisperEncoderLayer): + def __init__(self, config): + super().__init__(config) + self.self_attn = MindieAttention( + embed_dim=self.embed_dim, + num_heads=config.encoder_attention_heads, + dropout=config.attention_dropout, + config=config, + ) + + +class MindieWhisperEncoder(WhisperEncoder): + + def __init__(self, config): + super().__init__(config) + self.layers = nn.ModuleList([MindieWhisperEncoderLayer(config) for _ in range(config.encoder_layers)]) + + +class MindieWhisperModel(WhisperModel): + + def __init__(self, config): + super().__init__(config) + self.decoder = MindieWhisperDecoder(config) + self.encoder = MindieWhisperEncoder(config) + + def forward( + self, + encoder_outputs, + decoder_input_ids, + past_key_values, + use_cache: Optional[bool] = False, + return_dict: Optional[bool] = True, + input_features: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None + ) -> List[torch.Tensor]: + if input_features is None and encoder_outputs is None: + raise ValueError("Make sure to provide either `input_features` or `encoder_outputs` to `forward`.") + + if encoder_outputs is None: + input_features = self._mask_input_features(input_features, attention_mask=attention_mask) + + encoder_outputs = self.encoder( + input_features, + head_mask=None, + output_attentions=None, + output_hidden_states=None, + return_dict=return_dict, + ) + + decoder_outputs = self.decoder( + input_ids=decoder_input_ids, + encoder_hidden_states=encoder_outputs, + past_key_values=past_key_values, + use_cache=use_cache + ) + return decoder_outputs + + +class MindieWhisperForConditionalGeneration(WhisperForConditionalGeneration, GenerationMixin): + + def __init__(self, config): + super().__init__(config) + self.model = MindieWhisperModel(config) + self.has_load = False + self.has_compile = False + self.mindie_encoder = None + self.mindie_decoder_prefill = None + self.mindie_decoder = None + self.save_path = None + self.batch_size = -1 + self.encoder_seq_len = 1500 + self.init_encoder_seq_len = self.encoder_seq_len - 1 + self.file_prefix_names = ["mindie_whisper_encoder_bs", + "mindie_decoder_prefill_bs", + "mindie_whisper_decoder_bs" + ] + + def compile(self, args): + print("Start compiling Mindie-Whisper, it will take some time, please wait.") + save_path = args.save_path + model_path = args.model_path + + if not save_path: + raise ValueError("Please provide the directory where the compiled model saved.") + if not os.path.exists(save_path): + os.makedirs(save_path) + print(f"Directory {save_path} created.") + else: + print(f"Directory 
{save_path} already exists.") + + encoder = self.get_encoder() + + class Encoder(torch.nn.Module): + + def __init__(self, model): + super().__init__() + self.model = model + + def forward(self, input_features): + return self.model(input_features=input_features, return_dict=False) + + mel_feature_size = 128 + max_frames = 3000 + input_features = torch.randn([args.bs, mel_feature_size, max_frames]) + encoder_traced = torch.jit.trace(Encoder(encoder), (input_features)) + input_info = [mindietorch.Input(shape=(args.bs, mel_feature_size, max_frames))] + print("Start AOE optimization_level 1.") + compiled = mindietorch.compile(encoder_traced, + inputs=input_info, + precision_policy=mindietorch.PrecisionPolicy.FP16, + soc_version=args.soc_version, + optimization_level=1 + ) + print("AOE optimize finished and start compile encoder.") + compiled = mindietorch.compile(encoder_traced, + inputs=input_info, + precision_policy=mindietorch.PrecisionPolicy.FP16, + soc_version=args.soc_version, + ) + save_file = os.path.join(save_path, f"{self.file_prefix_names[0]}{args.bs}.ts") + torch.jit.save(compiled, save_file) + + traced_file = os.path.join(save_path, f"{self.file_prefix_names[0]}{args.bs}.pt") + torch.jit.save(encoder_traced, traced_file) + print(f"Compile encoder success, saved in {save_file}") + + print("Start compiling prefill_decoder.") + encoder_seq_len = 1500 + hidden_size = 1280 + encoder_outputs = torch.randn([args.bs, encoder_seq_len, hidden_size]) + decoder_input_ids = torch.randint(1, 4, (args.bs, 1)) + prefill_decoder_traced = torch.jit.trace(self.eval(), (decoder_input_ids, encoder_outputs)) + input_info = [mindietorch.Input(shape=(args.bs, 1), dtype=dtype.INT64), + mindietorch.Input(shape=(args.bs, encoder_seq_len, hidden_size))] + prefill_decoder_compiled = mindietorch.compile(prefill_decoder_traced, + inputs=input_info, + precision_policy=mindietorch.PrecisionPolicy.FP16, + soc_version=args.soc_version) + + traced_file = os.path.join(save_path, f"{self.file_prefix_names[1]}{args.bs}.pt") + torch.jit.save(prefill_decoder_traced, traced_file) + + save_file = os.path.join(save_path, f"{self.file_prefix_names[1]}{args.bs}.ts") + torch.jit.save(prefill_decoder_compiled, save_file) + print(f"Compile prefill_decoder success, saved in {save_file}.") + + print("Start compiling decoder.") + hidden_size = 1280 + head_num = 20 + head_size = 64 + seq_len = 1 + layer_nums = 32 + max_len = 448 + + encoder_outputs = torch.randn([args.bs, self.encoder_seq_len, hidden_size]) + decoder_input_ids = torch.randint(1, 4, (args.bs, 1)) + all_past_key_value = [torch.randn([args.bs, head_num, seq_len, head_size]), + torch.randn([args.bs, head_num, seq_len, head_size]), + torch.randn([args.bs, head_num, self.encoder_seq_len, head_size]), + torch.randn([args.bs, head_num, self.encoder_seq_len, head_size])] * layer_nums + traced_args = [decoder_input_ids, encoder_outputs] + traced_args.extend(all_past_key_value) + traced_decoder = torch.jit.trace(self.eval(), example_inputs=traced_args) + # if using PFA, the BSND layout needed + key_value_infos = [mindietorch.Input( + min_shape=(args.bs, head_num, 1, head_size), + max_shape=(args.bs, head_num, max_len, head_size), + dtype=dtype.FLOAT16 + ), + mindietorch.Input( + min_shape=(args.bs, head_num, 1, head_size), + max_shape=(args.bs, head_num, max_len, head_size), + dtype=dtype.FLOAT16 + ), + mindietorch.Input( + min_shape=(args.bs, head_num, self.encoder_seq_len, head_size), + max_shape=(args.bs, head_num, self.encoder_seq_len, head_size), + dtype=dtype.FLOAT16 + ), 
+ mindietorch.Input( + min_shape=(args.bs, head_num, self.encoder_seq_len, head_size), + max_shape=(args.bs, head_num, self.encoder_seq_len, head_size), + dtype=dtype.FLOAT16 + ) + ] * layer_nums + input_info = [mindietorch.Input(min_shape=(args.bs, 1), + max_shape=(args.bs, 1), dtype=dtype.INT64 + ), + mindietorch.Input(min_shape=(args.bs, self.encoder_seq_len, hidden_size), + max_shape=(args.bs, self.encoder_seq_len, hidden_size) + ) + ] + input_info.extend(key_value_infos) + float_size = 4 + # pre set malloc each output buffer size/MB + buffer_size = math.ceil((args.bs * self.encoder_seq_len * hidden_size * float_size) / 1024 / 1024) + print(f"Set {buffer_size}/MB for each output.") + compiled_decoder = mindietorch.compile(traced_decoder, + inputs=input_info, + precision_policy=mindietorch.PrecisionPolicy.FP16, + soc_version=args.soc_version, + default_buffer_size_vec=[buffer_size]) + + traced_file = os.path.join(save_path, f"{self.file_prefix_names[2]}{args.bs}.pt") + torch.jit.save(traced_decoder, traced_file) + + save_file = os.path.join(save_path, f"{self.file_prefix_names[2]}{args.bs}.ts") + torch.jit.save(compiled_decoder, save_file) + + print(f"Compile whisper_decoder success, saved in {save_file}.") + self.has_compile = True + + def load_mindie_models(self, save_path, batch_size): + if not (save_path and batch_size): + raise ValueError(f"Please provide batch_size and the directory where the compiled models saved,\ + but found save_path is {save_path}, batch_size is{batch_size}.") + self._check_save_path(save_path, batch_size) + self.batch_size = batch_size + # self._init_past_key_values(self.batch_size) + + if not self.has_load: + self.mindie_encoder = torch.jit.load(f"{save_path}/{self.file_prefix_names[0]}{batch_size}.ts") + print(f"load {self.file_prefix_names[0]}{batch_size}.ts success.") + + self.mindie_decoder_prefill = torch.jit.load(f"{save_path}/{self.file_prefix_names[1]}{batch_size}.ts") + print(f"load {self.file_prefix_names[1]}{batch_size}.ts success.") + + self.mindie_decoder = torch.jit.load(f"{save_path}/{self.file_prefix_names[2]}{batch_size}.ts") + print(f"load {self.file_prefix_names[2]}{batch_size}.ts success.") + self.has_load = True + else: + print("Mindie whisper has already load.") + + def forward(self, *args): + """For performance reasons, all layers'past_key_value need to be used as inputs and outputs.""" + if len(args) not in (2, 130): + raise ValueError(f"The args length of forward can only be 2 or 130, but got {len(args)}") + decoder_input_ids = args[0] + encoder_outputs = args[1] + past_key_values = args[2:] if len(args) == 130 else None + outputs = self.model( + decoder_input_ids=decoder_input_ids, + encoder_outputs=encoder_outputs, + past_key_values=past_key_values, + use_cache=True, + return_dict=False, + input_features=None + ) + lm_logits = self.proj_out(outputs[0]) + return [lm_logits] + outputs[1] + + def greedy_search( + self, + input_ids: torch.LongTensor, + logits_processor: Optional[LogitsProcessorList] = None, + stopping_criteria: Optional[StoppingCriteriaList] = None, + max_length: Optional[int] = None, + pad_token_id: Optional[int] = None, + eos_token_id: Optional[Union[int, List[int]]] = None, + output_scores: Optional[bool] = None, + return_dict_in_generate: Optional[bool] = None, + synced_gpus: bool = False, + **model_kwargs): + logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList() + stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList() + if 
max_length is not None: + print( + "`max_length` is deprecated in this function, use" + " `stopping_criteria=StoppingCriteriaList([MaxLengthCriteria(max_length=max_length)])` instead." + ) + stopping_criteria = validate_stopping_criteria(stopping_criteria, max_length) + pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id + eos_token_id = eos_token_id if eos_token_id is not None else self.generation_config.eos_token_id + if isinstance(eos_token_id, int): + eos_token_id = [eos_token_id] + eos_token_id_tensor = torch.tensor(eos_token_id).to("cpu") if eos_token_id is not None else None + output_scores = output_scores if output_scores is not None else self.generation_config.output_scores + # init attention / hidden states / scores tuples + scores = () if (return_dict_in_generate and output_scores) else None + + # keep track of which sequences are already finished + unfinished_sequences = torch.ones(input_ids.shape[0], dtype=torch.long, device="cpu") + + this_peer_finished = False # used by synced_gpus only + + is_first_step = True + while True: + + model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs) + args = [model_inputs["decoder_input_ids"].contiguous().to("npu"), model_inputs["encoder_outputs"]] + if is_first_step: + outputs = self.mindie_decoder_prefill(*args) + is_first_step = False + else: + args.extend(model_inputs["past_key_values"]) + outputs = self.mindie_decoder(*args) + if synced_gpus and this_peer_finished: + continue + + next_token_logits = outputs[0].to("cpu")[:, -1, :] + + # pre-process distribution + next_tokens_scores = logits_processor(input_ids.to("cpu"), next_token_logits) + + # argmax + next_tokens = torch.argmax(next_tokens_scores, dim=-1) + + # finished sentences should have their next token be a padding token + if eos_token_id is not None: + if pad_token_id is None: + raise ValueError("If `eos_token_id` is defined, make sure that `pad_token_id` is defined.") + next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences) + + # update generated ids, model inputs, and length for next step + input_ids = torch.cat([input_ids.to("cpu"), next_tokens[:, None]], dim=-1) + + model_kwargs = self._update_model_kwargs_for_generation( + outputs, model_kwargs + ) + + # if eos_token was found in one sentence, set sentence to finished + if eos_token_id_tensor is not None: + unfinished_sequences = unfinished_sequences.mul( + next_tokens.tile(eos_token_id_tensor.shape[0], 1).ne(eos_token_id_tensor.unsqueeze(1)).prod(dim=0) + ) + + # stop when each sentence is finished + if unfinished_sequences.max() == 0: + this_peer_finished = True + + # stop if we exceed the maximum length + if stopping_criteria(input_ids, scores): + this_peer_finished = True + + if this_peer_finished and not synced_gpus: + break + + return input_ids + + def generate( + self, + input_features: Optional[torch.Tensor] = None, + generation_config=None, + logits_processor=None, + stopping_criteria=None, + prefix_allowed_tokens_fn=None, + synced_gpus=False, + return_dict_in_generate: Optional[bool] = None, + **kwargs, + ): + """ + 1 only support speech < 30s + 2 only support encoder outputs as input + """ + if generation_config is None: + generation_config = copy.deepcopy(self.generation_config) + is_shortform = True + if is_shortform: + outputs = GenerationMixin.generate( + self, + input_features, + generation_config, + logits_processor, + stopping_criteria, + prefix_allowed_tokens_fn, + synced_gpus, + 
return_dict_in_generate=return_dict_in_generate, + **kwargs + ) + return outputs + + def _update_model_kwargs_for_generation( + self, + outputs, + model_inputs, + ): + # update past_key_values + model_inputs["past_key_values"] = outputs[1:] + return model_inputs + + def _maybe_initialize_input_ids_for_generation( + self, + inputs: Optional[torch.Tensor] = None, + bos_token_id: Optional[int] = None, + model_kwargs: Optional[Dict[str, torch.Tensor]] = None, + ) -> torch.LongTensor: + """Initializes input ids for generation, if necessary.""" + if inputs is not None: + return inputs + + encoder_outputs = model_kwargs.get("encoder_outputs") + if self.config.is_encoder_decoder and encoder_outputs is not None: + # make dummy input_ids with value -100, as a sanity check ensuring that they won't be used for encoding + shape = encoder_outputs.size()[:-1] + input_ids = torch.ones(shape, dtype=torch.long, device="cpu") * -100 + return input_ids.to("npu") + + if bos_token_id is None: + raise ValueError("`bos_token_id` has to be defined when no `input_ids` are provided.") + + # If there is some tensor in `model_kwargs`, we can infer the batch size from it. This is helpful with + # soft-prompting or in multimodal implementations built on top of decoder-only language models. + batch_size = 1 + for value in model_kwargs.values(): + if isinstance(value, torch.Tensor): + batch_size = value.shape[0] + break + return torch.ones((batch_size, 1), dtype=torch.long, device=self.device) * bos_token_id + + def _prepare_decoder_input_ids_for_generation( + self, + batch_size: int, + model_input_name: str, + model_kwargs: Dict[str, torch.Tensor], + decoder_start_token_id: int = None, + bos_token_id: int = None, + device: torch.device = None, + ) -> Tuple[torch.LongTensor, Dict[str, torch.Tensor]]: + """Prepares `decoder_input_ids` for generation with encoder-decoder models""" + # 1. Check whether the user has defined `decoder_input_ids` manually. To facilitate in terms of input naming, + # we also allow the user to pass it under `input_ids`, if the encoder does not use it as the main input. + if model_kwargs is not None and "decoder_input_ids" in model_kwargs: + decoder_input_ids = model_kwargs.pop("decoder_input_ids") + elif "input_ids" in model_kwargs and model_input_name != "input_ids": + decoder_input_ids = model_kwargs.pop("input_ids") + else: + decoder_input_ids = None + + # 2. Encoder-decoder models expect the `decoder_input_ids` to start with a special token. Let's ensure that. 
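+        # For Whisper this resolves to the <|startoftranscript|> token id taken from the model's generation config.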
+ decoder_start_token_id = self._get_decoder_start_token_id(decoder_start_token_id, bos_token_id) + if device is None: + device = self.device + decoder_input_ids_start = torch.ones((batch_size, 1), dtype=torch.long, device="cpu") * decoder_start_token_id + + # no user input -> use decoder_start_token_id as decoder_input_ids + if decoder_input_ids is None: + decoder_input_ids = decoder_input_ids_start + # exception: Donut checkpoints have task-specific decoder starts and don't expect a BOS token + elif self.config.model_type == "vision-encoder-decoder" and "donut" in self.name_or_path.lower(): + pass + # user input but doesn't start with decoder_start_token_id -> prepend decoder_start_token_id (and adjust + # decoder_attention_mask if provided) + elif (decoder_input_ids[:, 0] != decoder_start_token_id).all().item(): + decoder_input_ids = torch.cat([decoder_input_ids_start, decoder_input_ids], dim=-1) + if "decoder_attention_mask" in model_kwargs: + decoder_attention_mask = model_kwargs["decoder_attention_mask"] + decoder_attention_mask = torch.cat( + (torch.ones_like(decoder_attention_mask)[:, :1], decoder_attention_mask), + dim=-1, + ) + model_kwargs["decoder_attention_mask"] = decoder_attention_mask.to(self.device) + + return decoder_input_ids.to("npu"), model_kwargs + + def _check_save_path(self, save_path, batch_size): + file_list = os.listdir(save_path) + expected_files = [file + f"{batch_size}.ts" for file in self.file_prefix_names] + for file in expected_files: + if file not in file_list: + raise ValueError(f"Expected file name is {file}, but can't be found in path: {save_path}") + + def prepare_inputs_for_generation( + self, + decoder_input_ids, + past_key_values=None, + encoder_outputs=None, + **kwargs + ): + if past_key_values is not None: + past_key_values_shape = past_key_values[0].shape # B N S D + past_length = past_key_values_shape[-2] + decoder_input_ids_shape = decoder_input_ids.shape + # Some generation methods already pass only the last input ID + if decoder_input_ids_shape[1] > past_length: + remove_prefix_length = past_length + else: + # Default to old behavior: keep only final ID + remove_prefix_length = decoder_input_ids_shape[1] - 1 + + decoder_input_ids = decoder_input_ids[:, remove_prefix_length:] + + return { + "encoder_outputs": encoder_outputs, + "past_key_values": past_key_values, + "decoder_input_ids": decoder_input_ids + } + + +class AishellDataset(Dataset): + def __init__(self, audio_dir=None, trans_file = None, sampling_rate = None, *args, **kwargs): + super(AishellDataset, self).__init__() + self.root_dir = audio_dir + self.dataset_info_list = [] + self.label_file = trans_file + self.data_label_map = {} + self.sampling_rate = sampling_rate + self.DURATION = 36108.92 + self.audio_2_trans() + + + def get_total_duration(self): + return self.DURATION + + def data_2_label(self): + with open(self.label_file, 'r', encoding='utf-8') as file: + for line in file: + name = line.strip().split(" ")[0] + trans = "".join(x for x in line.strip().split(" ")[1:]) + self.data_label_map[name] = trans + + def audio_2_trans(self): + self.data_2_label() + for root, _, files in os.walk(self.root_dir): + for filename in files: + filepath = os.path.join(root, filename) + trans = self.data_label_map[filename.split(".")[0]] + self.dataset_info_list.append((filepath, trans)) + + def __len__(self): + return len(self.dataset_info_list) + + def __getitem__(self, idx): + audio_list = [] + transcript_list = [] + + for audio_path, trans in self.dataset_info_list[idx]: + audio, _ = 
librosa.load(audio_path, sr=self.sampling_rate)
+            audio_list.append(audio)
+            transcript_list.append(trans)
+        return audio_list, transcript_list
+
+
+class MindiePipeline:
+
+    def __init__(self, model: MindieWhisperForConditionalGeneration, device, save_path, batch_size, *args, **kwargs):
+        self.model = model
+        self.device = device
+        if not isinstance(model, MindieWhisperForConditionalGeneration):
+            raise ValueError(f"Please provide MindieWhisperForConditionalGeneration, found {type(model)}")
+
+        if not (save_path and batch_size):
+            raise ValueError(f"Please provide compiled model save path and batch_size, "
+                             f"but found save_path is {save_path}, batch_size is {batch_size}.")
+
+        self.model.load_mindie_models(save_path, batch_size)
+
+    def __call__(self, input_features, generate_kwargs=None):
+        # avoid a shared mutable default dict that would keep state across calls
+        generate_kwargs = {} if generate_kwargs is None else generate_kwargs
+        inputs = input_features.to(self.device)
+        generate_kwargs["encoder_outputs"] = self.model.mindie_encoder(inputs)[0]
+        tokens = self.model.generate(attention_mask=None, **generate_kwargs)
+        return tokens
+
+
+def DataLoader(dataset: AishellDataset, batch_size=1, drop_last=True):
+    total_length = len(dataset)
+    num_batch = total_length // batch_size
+    start_index = 0
+    for _ in range(num_batch):
+        end_index = min(start_index + batch_size, total_length)
+        yield dataset[start_index:end_index]
+        start_index = end_index
+    # yield the remaining smaller batch only when drop_last is disabled
+    if not drop_last and start_index < total_length:
+        yield dataset[start_index:total_length]
\ No newline at end of file
diff --git a/MindIE/MindIE-Torch/built-in/audio/whisper_large_v3/readme.md b/MindIE/MindIE-Torch/built-in/audio/whisper_large_v3/readme.md
new file mode 100644
index 0000000000000000000000000000000000000000..3b3d26ed4a6a5bb0ae2da0b70898624112dd3912
--- /dev/null
+++ b/MindIE/MindIE-Torch/built-in/audio/whisper_large_v3/readme.md
@@ -0,0 +1,144 @@
+# Whisper-large-v3 Model Inference Guide
+
+- [Overview](#overview)
+- [Inference Environment Setup](#inference-environment-setup)
+- [Quick Start](#quick-start)
+  - [Getting the Source Code](#getting-the-source-code)
+  - [Model Inference](#model-inference)
+
+# Overview
+
+This project deploys the whisper-large-v3 model with mindietorch.
+
+
+# Inference Environment Setup
+
+- The model requires the following software versions.
+
+  **Table 1** Version compatibility
+
+  | Component | Version | Setup Guide |
+  |-----------| ------- | ------------ |
+  | Python | 3.10.13 | - |
+  | torch | 2.1.0+cpu | - |
+  | torch_audio | 2.1.0+cpu | - |
+  | CANN | 8.0.RC3 | - |
+  | MindIE | 1.0.RC3 | - |
+
+# Quick Start
+## Getting the Source Code
+
+1. Install the MindIE package.
+
+   ```bash
+   # install mindie
+   chmod +x ./Ascend-mindie_xxx.run
+   ./Ascend-mindie_xxx.run --install
+   source /usr/local/Ascend/mindie/set_env.sh
+   ```
+
+2. Download the model weights from:
+   ```bash
+   https://huggingface.co/openai/whisper-large-v3/tree/main
+   ```
+   Store the weight files in the model_path folder under the current directory; create that folder first.
+
+3. Install the dependencies.
+   ```
+   pip3 install torch==2.1.0 --index-url https://download.pytorch.org/whl/cpu
+   pip3 install nltk
+   pip3 install librosa
+   pip3 install transformers==4.36.0
+   pip3 install numpy==1.26.0
+   ```
+
+4. Patch transformers: go to the transformers installation directory, open modeling_whisper.py, and change line 408 from
+   ```
+   past_key_value = (key_states, value_states)
+   ```
+   to
+   ```
+   past_key_value = (key_states.to(torch.float16), value_states.to(torch.float16))
+   ```
+
+
+
+## Model Inference
+1. Set the upper limit of the mindie memory pool to 32 by exporting the environment variable below. If the pool is too small, repeated memory allocation and release degrades performance.
+   ```
+   export TORCH_AIE_NPU_CACHE_MAX_SIZE=32
+   ```
+
+2. Compile the models and run inference.
+   Run the commands below. The dataset used by the sample.py script can be downloaded from https://openslr.org/33
+   as data_aishell.tgz. To use another dataset, adapt the dataset parsing accordingly.
+   ```bash
+   tar -xzf data_aishell.tgz
+   cd ./data_aishell/wav
+   find . -type f -name "*.tar.gz" -exec tar -xzf {} \;
+   ```
+   Directory structure after extraction:
+   ```
+   data_aishell
+   |---transcript
+   |   |--aishell_transcript_v0.8.txt   # transcription labels
+   |---wav
+   |   |--S0*.tar.gz
+   |   |--test
+   |   |--dev
+   |   |--train
+   ```
+
+   ```bash
+   python3 sample.py \
+       -model_path ./model_path \
+       -audio_dir ./data_aishell/wav/test \
+       -trans_file ./data_aishell/transcript/aishell_transcript_v0.8.txt \
+       -bs 8 \
+       -save_path ./compiled_models \
+       -device_id 0 \
+       -soc_version Ascend310P3
+   ```
+
+   Parameters:
+   - -model_path: path to the pretrained model, required.
+   - -audio_dir: directory containing the audio dataset, required.
+   - -trans_file: transcription label file for the audio, required.
+   - -bs: batch size, default 8, optional.
+   - -save_path: directory where the compiled models are saved, required.
+   - -device_id: ID of the device the model runs on, default 0, optional.
+   - -soc_version: chip type, default Ascend310P3; only Ascend310P3 and Ascend910B4 are supported.
+
+   Constraints:
+   1. Dynamic batch sizes are not supported; after changing batch_size the graphs must be recompiled.
+   2. On Ascend 310P3 the best-performing batch size is 8; larger batch sizes can run out of device memory.
+   3. Only short audio is supported (a single clip must be shorter than 30 s).
+   4. Inference must be driven through MindiePipeline (see the single-file sketch below).
+
+   Metrics:
+   1. Performance: with bs=8, the transcription ratio (total audio duration / total inference time) is 6.3x.
+   2. Accuracy: using the word error rate as the metric, the error rate on the AISHELL dataset is 10.06%.
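+
+   For reference, a minimal single-file sketch of the MindiePipeline flow is shown below. It assumes the models have
+   already been compiled into ./compiled_models for batch size 1 (for example by a previous sample.py run with -bs 1),
+   and test.wav is only a placeholder for a local 16 kHz audio file; adjust names and paths to your setup.
+   ```python
+   import librosa
+   import mindietorch
+   from transformers import WhisperProcessor
+   from mindie_whisper import MindieWhisperForConditionalGeneration, MindiePipeline
+
+   model_path = "./model_path"        # pretrained whisper-large-v3 weights
+   save_path = "./compiled_models"    # directory holding the pre-compiled .ts models (assumed to exist)
+   device_id = 0
+
+   mindietorch.set_device(device_id)
+   mindie_whisper = MindieWhisperForConditionalGeneration.from_pretrained(model_path)
+   processor = WhisperProcessor.from_pretrained(model_path)
+   # the last argument must match the batch size the models were compiled with
+   pipe = MindiePipeline(mindie_whisper, f"npu:{device_id}", save_path, 1)
+
+   sampling_rate = processor.feature_extractor.sampling_rate
+   audio, _ = librosa.load("test.wav", sr=sampling_rate)   # hypothetical input file
+   features = processor([audio], sampling_rate=sampling_rate, return_tensors="pt").input_features
+   tokens = pipe(features)
+   print(processor.batch_decode(tokens, skip_special_tokens=True))
+   ```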
+
+
+## Model Accuracy Baseline
+1. Run the original model with torch_npu to obtain the accuracy baseline; torch-npu needs to be installed.
+   ```
+   pip3 install torch-npu==2.1.0
+   ```
+2. Run the test_tools script.
+   ```bash
+   python3 test_tools.py \
+   -model_path ./model_path \
+   -audio_dir ./data_aishell/wav/test \
+   -trans_file ./data_aishell/transcript/aishell_transcript_v0.8.txt \
+   -bs 1 \
+   -device_id 0
+   ```
+3. Notes:
+   1. Running the script takes about 5 hours.
+   2. The resulting word error rate is 10.58%.
\ No newline at end of file
diff --git a/MindIE/MindIE-Torch/built-in/audio/whisper_large_v3/sample.py b/MindIE/MindIE-Torch/built-in/audio/whisper_large_v3/sample.py
new file mode 100644
index 0000000000000000000000000000000000000000..b3a30990852af01fc507aa3412c7c81d6c3d1547
--- /dev/null
+++ b/MindIE/MindIE-Torch/built-in/audio/whisper_large_v3/sample.py
@@ -0,0 +1,57 @@
+import argparse
+import time
+from transformers import WhisperProcessor
+from nltk.metrics.distance import edit_distance
+
+import mindietorch
+from mindie_whisper import MindieWhisperForConditionalGeneration, MindiePipeline, AishellDataset, DataLoader
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument('-model_path', type=str, required=True, help="please provide model path.")
+    parser.add_argument('-audio_dir', type=str, required=True, help="please provide speech samples dir.")
+    parser.add_argument('-trans_file', type=str, required=True, help="please provide label file.")
+    parser.add_argument('-bs', type=int, default=8, help="please provide batch_size, default:8.")
+    parser.add_argument('-soc_version', type=str, choices=["Ascend310P3", "Ascend910B4"], default="Ascend310P3",
+                        help="please provide soc_version, default:Ascend310P3.")
+    parser.add_argument('-save_path', type=str, default="compiled_models", help="compiled models save dir.")
+    parser.add_argument('-device_id', type=int, default=0)
+
+    args = parser.parse_args()
+    device = f"npu:{args.device_id}"
+    mindie_whisper = MindieWhisperForConditionalGeneration.from_pretrained(args.model_path)
+    mindie_whisper.compile(args)
+    mindietorch.set_device(args.device_id)
+
+    processor = WhisperProcessor.from_pretrained(args.model_path)
+    mindie_pipe = MindiePipeline(mindie_whisper, device, args.save_path, args.bs)
+
+    sampling_rate = processor.feature_extractor.sampling_rate
+    aishell_dataset = AishellDataset(args.audio_dir, args.trans_file, sampling_rate)
+    duration = aishell_dataset.get_total_duration()
+    print(f"speech sample num is {len(aishell_dataset)}, duration (s) is {duration}")
+    data_loader = DataLoader(aishell_dataset, batch_size=args.bs)
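+    # First pass over the AISHELL test set: accumulate the per-utterance error rate
+    # (character-level edit distance divided by the reference length).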
+    wer = 0
+    step = 0
+    for inp, labels in data_loader:
+        input_features = processor(inp, sampling_rate=sampling_rate, return_tensors="pt").input_features
+        predicted_ids = mindie_pipe(input_features)
+        transcription = processor.batch_decode(predicted_ids, skip_special_tokens=True)
+        for predict, label in zip(transcription, labels):
+            predict = [w for w in predict]
+            label = [w for w in label]
+            wer += edit_distance(predict, label) / len(label)
+            step += 1
+        print(f"predict {transcription}, labels {labels}")
+
+    # reset dataset and dataloader for the timed pass
+    aishell_dataset = AishellDataset(args.audio_dir, args.trans_file, sampling_rate)
+    data_loader = DataLoader(aishell_dataset, batch_size=args.bs)
+    t1 = time.time()
+    for inp, labels in data_loader:
+        input_features = processor(inp, sampling_rate=sampling_rate, return_tensors="pt").input_features
+        predicted_ids = mindie_pipe(input_features)
+        transcription = processor.batch_decode(predicted_ids, skip_special_tokens=True)
+    t2 = time.time()
+    print(f"word-error-rate {wer / step}, performance {duration / (t2 - t1)}")
\ No newline at end of file
diff --git a/MindIE/MindIE-Torch/built-in/audio/whisper_large_v3/test_tools.py b/MindIE/MindIE-Torch/built-in/audio/whisper_large_v3/test_tools.py
new file mode 100644
index 0000000000000000000000000000000000000000..ae1f8e6755638266629dd8794a2985706d6beb34
--- /dev/null
+++ b/MindIE/MindIE-Torch/built-in/audio/whisper_large_v3/test_tools.py
@@ -0,0 +1,46 @@
+import argparse
+import time
+from typing import Optional, Tuple, Union, List, Dict
+from collections import OrderedDict
+
+import torch
+import torch_npu
+from nltk.metrics.distance import edit_distance
+from transformers import WhisperProcessor, WhisperForConditionalGeneration
+from mindie_whisper import AishellDataset, DataLoader
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument('-model_path', type=str, required=True, help="please provide model path.")
+    parser.add_argument('-audio_dir', type=str, required=True, help="please provide speech samples dir.")
+    parser.add_argument('-trans_file', type=str, required=True, help="please provide label file.")
+    parser.add_argument('-bs', type=int, default=8, help="please provide batch_size, default:8.")
+    parser.add_argument('-soc_version', type=str, choices=["Ascend310P3", "Ascend910B4"], default="Ascend310P3",
+                        help="please provide soc_version, default:Ascend310P3.")
+    parser.add_argument('-device_id', type=int, default=0)
+    args = parser.parse_args()
+    device = f"npu:{args.device_id}"
+    torch_npu.npu.set_device(args.device_id)
+    whisper = WhisperForConditionalGeneration.from_pretrained(args.model_path).to(device)
+    processor = WhisperProcessor.from_pretrained(args.model_path)
+
+    sampling_rate = processor.feature_extractor.sampling_rate
+    aishell_dataset = AishellDataset(args.audio_dir, args.trans_file, sampling_rate)
+    duration = aishell_dataset.get_total_duration()
+    print(f"speech sample num is {len(aishell_dataset)}, duration (s) is {duration}")
+    data_loader = DataLoader(aishell_dataset, batch_size=args.bs)
+    wer = 0
+    step = 0
+    time_cost = 0.0
+    for inp, labels in data_loader:
+        start = time.time()
+        input_features = processor(inp, sampling_rate=sampling_rate, return_tensors="pt").input_features
+        predicted_ids = whisper.generate(input_features.to(device))
+        transcription = processor.batch_decode(predicted_ids, skip_special_tokens=True)
+        time_cost += time.time() - start
+        for predict, label in zip(transcription, labels):
+            predict = [w for w in predict]
+            label = [w for w in label]
+            wer += edit_distance(predict, label) / len(label)
+            step += 1
+        print(f"predict {transcription}, labels {labels}")
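+    # Report the averaged per-utterance error rate and the transcription ratio (total audio time / measured inference time).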
+    print(f"torch_npu: word-error-rate {wer / step}, performance {duration / time_cost}")