From ed5112120a5d31b227f45740eebd72416c5bdcfc Mon Sep 17 00:00:00 2001 From: tiansongxue Date: Tue, 3 Sep 2024 17:59:52 +0800 Subject: [PATCH 01/44] check init past key value check init past_key_value call init past key value once call remove prefill decoder remove load prefill decoder remove 910 310 mindie-whisper-large-v3 --- .../audio/whisper_large_v3/mindie_whisper.py | 695 ++++++++++++++++++ .../built-in/audio/whisper_large_v3/readme.md | 134 ++++ .../built-in/audio/whisper_large_v3/sample.py | 53 ++ .../audio/whisper_large_v3/test_tools.py | 46 ++ 4 files changed, 928 insertions(+) create mode 100644 MindIE/MindIE-Torch/built-in/audio/whisper_large_v3/mindie_whisper.py create mode 100644 MindIE/MindIE-Torch/built-in/audio/whisper_large_v3/readme.md create mode 100644 MindIE/MindIE-Torch/built-in/audio/whisper_large_v3/sample.py create mode 100644 MindIE/MindIE-Torch/built-in/audio/whisper_large_v3/test_tools.py diff --git a/MindIE/MindIE-Torch/built-in/audio/whisper_large_v3/mindie_whisper.py b/MindIE/MindIE-Torch/built-in/audio/whisper_large_v3/mindie_whisper.py new file mode 100644 index 0000000000..e0aeb3634b --- /dev/null +++ b/MindIE/MindIE-Torch/built-in/audio/whisper_large_v3/mindie_whisper.py @@ -0,0 +1,695 @@ +import argparse +import os +import copy +import math +import warnings +import librosa +from typing import Optional, Tuple, Union, List, Dict +from collections import OrderedDict + +import torch +from torch.utils.data import Dataset +import torch.nn.functional as F +import torch.utils.checkpoint +from torch import nn +from transformers import WhisperForConditionalGeneration +from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask +from transformers.models.whisper.modeling_whisper import WhisperDecoder,WhisperModel,WhisperDecoderLayer +from transformers.generation.logits_process import LogitsProcessorList +from transformers.generation.stopping_criteria import StoppingCriteriaList, validate_stopping_criteria +from transformers.generation.utils import GenerationMixin +import mindietorch +from mindietorch._enums import dtype + + +class MindieWhisperDecoderLayer(WhisperDecoderLayer): + + def __init__(self, config): + super().__init__(config) + self.embed_dim = config.d_model + self._init_past_key_values() + + def forward( + self, + hidden_states, + attention_mask, + encoder_hidden_states, + layer_head_mask, + cross_attn_layer_head_mask, + past_key_value, + cross_attn_past_key_value, + output_attentions, + use_cache, + encoder_attention_mask=None + ) -> torch.Tensor: + residual = hidden_states + hidden_states = self.self_attn_layer_norm(hidden_states) + + # Self Attention + self_attn_past_key_value = past_key_value if past_key_value is not None else None + hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states=hidden_states, + past_key_value=self_attn_past_key_value, + attention_mask=attention_mask, + layer_head_mask=layer_head_mask, + output_attentions=output_attentions, + ) + hidden_states = residual + hidden_states + + # Cross-Attention Block + cross_attn_present_key_value = None + cross_attn_weights = None + if encoder_hidden_states is not None: + residual = hidden_states + hidden_states = self.encoder_attn_layer_norm(hidden_states) + + hidden_states, cross_attn_weights, cross_attn_present_key_value = self.encoder_attn( + hidden_states=hidden_states, + key_value_states=encoder_hidden_states, + attention_mask=encoder_attention_mask, + layer_head_mask=cross_attn_layer_head_mask, + 
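+                # note: cross-attention K/V come from the static encoder output, so the
+                # tensors cached here are reused unchanged at every decoding step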
past_key_value=cross_attn_past_key_value, + output_attentions=output_attentions, + ) + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + hidden_states = residual + hidden_states + + # add cross-attn to positions 3,4 of present_key_value tuple + present_key_value = present_key_value + cross_attn_present_key_value + + # Fully Connected + residual = hidden_states + hidden_states = self.final_layer_norm(hidden_states) + hidden_states = hidden_states.reshape(-1, self.embed_dim) + hidden_states = self.activation_fn(self.fc1(hidden_states)) + hidden_states = self.fc2(hidden_states) + hidden_states = hidden_states.reshape(-1, 1, self.embed_dim) + hidden_states = residual + hidden_states + + outputs = (hidden_states,) + + if use_cache: + outputs += (present_key_value,) + + return outputs + + +class MindieWhisperDecoder(WhisperDecoder): + + def __init__(self, config): + super().__init__(config) + self.layers = nn.ModuleList([MindieWhisperDecoderLayer(config) for _ in range(config.decoder_layers)]) + self.config = config + + def forward( + self, + input_ids, + encoder_hidden_states, + past_key_values, + use_cache=True, + attention_mask=None): + if input_ids is None: + raise ValueError("You have to specify either decoder_input_ids") + inputs_embeds = self.embed_tokens(input_ids) + + input_shape = input_ids.size() + input_ids = input_ids.view(-1, input_shape[-1]) + past_key_values_length = past_key_values[0].shape[2] if past_key_values is not None else 0 + + attention_mask = _prepare_4d_causal_attention_mask( + attention_mask, input_shape, inputs_embeds, past_key_values_length + ) + inputs_embeds = self.embed_tokens(input_ids) + # embed positions + positions = self.embed_positions(input_ids, past_key_values_length=past_key_values_length) + hidden_states = inputs_embeds + positions + + past_key_value_cache = [] + for idx, decoder_layer in enumerate(self.layers): + + past_key_value = (past_key_values[4*idx], past_key_values[4*idx + 1])\ + if past_key_values is not None else None + cross_past_key_value = (past_key_values[4*idx + 2], past_key_values[4*idx + 3])\ + if past_key_values is not None else None + + layer_outputs = decoder_layer( + hidden_states, + attention_mask=attention_mask, + encoder_hidden_states=encoder_hidden_states, + layer_head_mask=None, + cross_attn_layer_head_mask=None, + past_key_value=past_key_value, + cross_attn_past_key_value=cross_past_key_value, + output_attentions=None, + use_cache=use_cache) + + hidden_states = layer_outputs[0] + past_key_value_cache.extend(layer_outputs[1]) + + hidden_states = self.layer_norm(hidden_states) + return hidden_states, past_key_value_cache + + + +class MindieWhisperModel(WhisperModel): + + def __init__(self, config): + super().__init__(config) + self.decoder = MindieWhisperDecoder(config) + + def forward( + self, + encoder_outputs, + decoder_input_ids, + past_key_values, + use_cache: Optional[bool] = False, + return_dict: Optional[bool] = True, + input_features: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None + ) -> List[torch.Tensor]: + if input_features is None and encoder_outputs is None: + raise ValueError("Make sure to provide either `input_features` or `encoder_outputs` to `forward`.") + + if encoder_outputs is None: + input_features = self._mask_input_features(input_features, attention_mask=attention_mask) + + encoder_outputs = self.encoder( + input_features, + head_mask=None, + output_attentions=None, + output_hidden_states=None, + 
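+                # this fallback encoder call only runs when no precomputed encoder_outputs
+                # are supplied; MindiePipeline normally passes them in via generate()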
return_dict=return_dict, + ) + + decoder_outputs = self.decoder( + input_ids=decoder_input_ids, + encoder_hidden_states=encoder_outputs, + past_key_values=past_key_values, + use_cache=use_cache + ) + return decoder_outputs + + +class MindieWhisperForConditionalGeneration(WhisperForConditionalGeneration, GenerationMixin): + + def __init__(self, config): + super().__init__(config) + self.model = MindieWhisperModel(config) + self.has_load = False + self.has_compile = False + self.mindie_encoder = None + self.mindie_decoder_prefill = None + self.mindie_decoder = None + self.save_path = None + self.batch_size = -1 + self.encoder_seq_len = 1500 + self.init_encoder_seq_len = self.encoder_seq_len - 1 + + + self.file_prefix_names = ["mindie_whisper_encoder_bs", + "mindie_whisper_decoder_bs" + ] + + + def compile(self, args): + print("Start compiling Mindie-Whisper, it will take some time, please wait.") + save_path = args.save_path + model_path = args.model_path + + if not save_path: + raise ValueError("Please provide the directory where the compiled model saved.") + if not os.path.exists(save_path): + os.makedirs(save_path) + print(f"Directory {save_path} created.") + else: + print(f"Directory {save_path} already exists.") + + model = MindieWhisperForConditionalGeneration.from_pretrained(model_path) + encoder = model.get_encoder() + + class Encoder(torch.nn.Module): + + def __init__(self, model): + super().__init__() + self.model = model + + def forward(self, input_features): + return self.model(input_features=input_features, return_dict=False) + + mel_feature_size = 128 + max_frames = 3000 + input_features = torch.randn([args.bs, mel_feature_size, max_frames]) + encoder_traced = torch.jit.trace(Encoder(encoder), (input_features)) + input_info = [mindietorch.Input(shape=(args.bs, mel_feature_size, max_frames))] + print("Start compiling encoder.") + compiled = mindietorch.compile(encoder_traced, + inputs=input_info, + precision_policy=mindietorch.PrecisionPolicy.FP16, + soc_version=args.soc_version + ) + save_file = os.path.join(save_path, f"{self.file_prefix_names[0]}{args.bs}.ts") + torch.jit.save(compiled, save_file) + + traced_file = os.path.join(save_path, f"{self.file_prefix_names[0]}{args.bs}.pt") + torch.jit.save(encoder_traced, traced_file) + print(f"Compile encoder success, saved in {save_file}") + + # print("Start compiling prefill_decoder.") + # encoder_seq_len = 1500 + # hidden_size = 1280 + # encoder_outputs = torch.randn([args.bs, encoder_seq_len, hidden_size]) + # decoder_input_ids = torch.randint(1, 4, (args.bs, 1)) + # prefill_decoder_traced = torch.jit.trace(model.eval(), (decoder_input_ids, encoder_outputs)) + # input_info = [mindietorch.Input(shape=(args.bs, 1), dtype=dtype.INT64), + # mindietorch.Input(shape=(args.bs, encoder_seq_len, hidden_size))] + # prefill_decoder_compiled = mindietorch.compile(prefill_decoder_traced, + # inputs=input_info, + # precision_policy=mindietorch.PrecisionPolicy.FP16, + # soc_version=args.soc_version) + + # save_file = os.path.join(save_path, f"{self.file_prefix_names[1]}{args.bs}.ts") + # torch.jit.save(prefill_decoder_compiled, save_file) + + # traced_file = os.path.join(save_path, f"{self.file_prefix_names[1]}{args.bs}.pt") + # torch.jit.save(prefill_decoder_traced, traced_file) + + # print(f"Compile prefill_decoder success, saved in {save_file}.") + + print("Start compiling decoder.") + hidden_size = 1280 + head_num = 20 + head_size = 64 + seq_len = 1 + layer_nums = 32 + max_len = 448 + + encoder_outputs = torch.randn([args.bs, 
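+            # trace-time dummy inputs: decoder_input_ids, encoder states, then 4 KV tensors
+            # (self-attn K/V + cross-attn K/V) for each of the 32 layers: 2 + 128 = 130 inputs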
self.encoder_seq_len, hidden_size], dtype=torch.float16) + decoder_input_ids = torch.randint(1, 4, (args.bs, 1)) + all_past_key_value = [torch.randn([args.bs, head_num, seq_len, head_size], dtype=torch.float16), + torch.randn([args.bs, head_num, seq_len, head_size], dtype=torch.float16), + torch.randn([args.bs, head_num, self.encoder_seq_len, head_size], dtype=torch.float16), + torch.randn([args.bs, head_num, self.encoder_seq_len, head_size], dtype=torch.float16)] * layer_nums + traced_args = [decoder_input_ids, encoder_outputs] + traced_args.extend(all_past_key_value) + traced_decoder = torch.jit.trace(model.eval(), example_inputs=traced_args) + + key_value_infos = [mindietorch.Input( + min_shape=(args.bs, head_num, 0, head_size), + max_shape=(args.bs, head_num, max_len, head_size) + ), + mindietorch.Input( + min_shape=(args.bs, head_num, seq_len, head_size), + max_shape=(args.bs, head_num, max_len, head_size) + ), + mindietorch.Input( + min_shape=(args.bs, head_num, self.init_encoder_seq_len, head_size), + max_shape=(args.bs, head_num, self.encoder_seq_len, head_size) + ), + mindietorch.Input( + min_shape=(args.bs, head_num, self.init_encoder_seq_len, head_size), + max_shape=(args.bs, head_num, self.encoder_seq_len, head_size) + ) + ] * layer_nums + input_info = [mindietorch.Input(min_shape=(args.bs, 1), max_shape=(args.bs, 1), dtype=dtype.INT64), + mindietorch.Input(min_shape=(args.bs, self.encoder_seq_len), max_shape=(args.bs, self.encoder_seq_len))] + input_info.extend(key_value_infos) + buffer_size = math.ceil((args.bs * self.encoder_seq_len * hidden_size * 4) / 1024 / 1024) + print(f"Set {buffer_size}/MB for each output.") + compiled_decoder = mindietorch.compile(traced_decoder, + inputs=input_info, + precision_policy=mindietorch.PrecisionPolicy.FP16, + soc_version=args.soc_version, + default_buffer_size_vec=[buffer_size]) + + save_file = os.path.join(save_path, f"{self.file_prefix_names[2]}{args.bs}.ts") + torch.jit.save(compiled_decoder, save_file) + + traced_file = os.path.join(save_path, f"{self.file_prefix_names[2]}{args.bs}.pt") + torch.jit.save(traced_decoder, traced_file) + + print(f"Compile whisper_decoder success, saved in {save_file}.") + self.has_compile = True + + def load_mindie_models(self, save_path, batch_size): + if not (save_path and batch_size): + raise ValueError(f"Please provide batch_size and the directory where the compiled models saved,\ + but found save_path is {save_path}, batch_size is{batch_size}.") + self._check_save_path(save_path, batch_size) + self.batch_size = batch_size + + if not self.has_load: + self.mindie_encoder = torch.jit.load(f"{save_path}/{self.file_prefix_names[0]}{batch_size}.ts") + print(f"load {self.file_prefix_names[0]}{batch_size}.ts success.") + + # self.mindie_decoder_prefill = torch.jit.load(f"{save_path}/{self.file_prefix_names[1]}{batch_size}.ts") + # print(f"load {self.file_prefix_names[1]}{batch_size}.ts success.") + + self.mindie_decoder = torch.jit.load(f"{save_path}/{self.file_prefix_names[2]}{batch_size}.ts") + print(f"load {self.file_prefix_names[2]}{batch_size}.ts success.") + self.has_load = True + else: + print("Mindie whisper has already load.") + + def forward(self, *args): + """For performance reasons, all layers'past_key_value need to be used as inputs and outputs.""" + if len(args) not in (2, 130): + raise ValueError(f"The args length of forward can only be 2 or 130, but got {len(args)}") + decoder_input_ids = args[0] + encoder_outputs = args[1] + past_key_values = args[2:] if len(args) == 130 else None + outputs 
= self.model( + decoder_input_ids=decoder_input_ids, + encoder_outputs=encoder_outputs, + past_key_values=past_key_values, + use_cache=True, + return_dict=False, + input_features=None + ) + lm_logits = self.proj_out(outputs[0]) + return [lm_logits] + outputs[1] + + def greedy_search( + self, + input_ids: torch.LongTensor, + logits_processor: Optional[LogitsProcessorList] = None, + stopping_criteria: Optional[StoppingCriteriaList] = None, + max_length: Optional[int] = None, + pad_token_id: Optional[int] = None, + eos_token_id: Optional[Union[int, List[int]]] = None, + output_scores: Optional[bool] = None, + return_dict_in_generate: Optional[bool] = None, + synced_gpus: bool = False, + **model_kwargs): + logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList() + stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList() + if max_length is not None: + print( + "`max_length` is deprecated in this function, use" + " `stopping_criteria=StoppingCriteriaList([MaxLengthCriteria(max_length=max_length)])` instead." + ) + stopping_criteria = validate_stopping_criteria(stopping_criteria, max_length) + pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id + eos_token_id = eos_token_id if eos_token_id is not None else self.generation_config.eos_token_id + if isinstance(eos_token_id, int): + eos_token_id = [eos_token_id] + eos_token_id_tensor = torch.tensor(eos_token_id).to("cpu") if eos_token_id is not None else None + output_scores = output_scores if output_scores is not None else self.generation_config.output_scores + # init attention / hidden states / scores tuples + scores = () if (return_dict_in_generate and output_scores) else None + + # keep track of which sequences are already finished + unfinished_sequences = torch.ones(input_ids.shape[0], dtype=torch.long, device="cpu") + + this_peer_finished = False # used by synced_gpus only + + is_first_step = True + while True: + + model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs) + args = [model_inputs["decoder_input_ids"].contiguous().to("npu"), model_inputs["encoder_outputs"]] + if is_first_step: + args.extend(self._get_init_past_key_values()) + is_first_step = False + else: + args.extend(model_inputs["past_key_values"]) + outputs = self.mindie_decoder(*args) + if synced_gpus and this_peer_finished: + continue + + next_token_logits = outputs[0].to("cpu")[:, -1, :] + + # pre-process distribution + next_tokens_scores = logits_processor(input_ids.to("cpu"), next_token_logits) + + # argmax + next_tokens = torch.argmax(next_tokens_scores, dim=-1) + + # finished sentences should have their next token be a padding token + if eos_token_id is not None: + if pad_token_id is None: + raise ValueError("If `eos_token_id` is defined, make sure that `pad_token_id` is defined.") + next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences) + + # update generated ids, model inputs, and length for next step + input_ids = torch.cat([input_ids.to("cpu"), next_tokens[:, None]], dim=-1) + + model_kwargs = self._update_model_kwargs_for_generation( + outputs, model_kwargs + ) + + # if eos_token was found in one sentence, set sentence to finished + if eos_token_id_tensor is not None: + unfinished_sequences = unfinished_sequences.mul( + next_tokens.tile(eos_token_id_tensor.shape[0], 1).ne(eos_token_id_tensor.unsqueeze(1)).prod(dim=0) + ) + + # stop when each sentence is finished + if 
unfinished_sequences.max() == 0: + this_peer_finished = True + + # stop if we exceed the maximum length + if stopping_criteria(input_ids, scores): + this_peer_finished = True + + if this_peer_finished and not synced_gpus: + break + + return input_ids + + def generate( + self, + input_features: Optional[torch.Tensor] = None, + generation_config=None, + logits_processor=None, + stopping_criteria=None, + prefix_allowed_tokens_fn=None, + synced_gpus=False, + return_dict_in_generate: Optional[bool] = None, + **kwargs, + ): + """ + 1 only support speech < 30s + 2 only support encoder outputs as input + """ + if generation_config is None: + generation_config = copy.deepcopy(self.generation_config) + is_shortform = True + if is_shortform: + outputs = GenerationMixin.generate( + self, + input_features, + generation_config, + logits_processor, + stopping_criteria, + prefix_allowed_tokens_fn, + synced_gpus, + return_dict_in_generate=return_dict_in_generate, + **kwargs + ) + return outputs + + def _update_model_kwargs_for_generation( + self, + outputs, + model_inputs, + ): + # update past_key_values + model_inputs["past_key_values"] = outputs[1:] + return model_inputs + + def _maybe_initialize_input_ids_for_generation( + self, + inputs: Optional[torch.Tensor] = None, + bos_token_id: Optional[int] = None, + model_kwargs: Optional[Dict[str, torch.Tensor]] = None, + ) -> torch.LongTensor: + """Initializes input ids for generation, if necessary.""" + if inputs is not None: + return inputs + + encoder_outputs = model_kwargs.get("encoder_outputs") + if self.config.is_encoder_decoder and encoder_outputs is not None: + # make dummy input_ids with value -100, as a sanity check ensuring that they won't be used for encoding + shape = encoder_outputs.size()[:-1] + input_ids = torch.ones(shape, dtype=torch.long, device="cpu") * -100 + return input_ids.to("npu") + + if bos_token_id is None: + raise ValueError("`bos_token_id` has to be defined when no `input_ids` are provided.") + + # If there is some tensor in `model_kwargs`, we can infer the batch size from it. This is helpful with + # soft-prompting or in multimodal implementations built on top of decoder-only language models. + batch_size = 1 + for value in model_kwargs.values(): + if isinstance(value, torch.Tensor): + batch_size = value.shape[0] + break + return torch.ones((batch_size, 1), dtype=torch.long, device=self.device) * bos_token_id + + def _prepare_decoder_input_ids_for_generation( + self, + batch_size: int, + model_input_name: str, + model_kwargs: Dict[str, torch.Tensor], + decoder_start_token_id: int = None, + bos_token_id: int = None, + device: torch.device = None, + ) -> Tuple[torch.LongTensor, Dict[str, torch.Tensor]]: + """Prepares `decoder_input_ids` for generation with encoder-decoder models""" + # 1. Check whether the user has defined `decoder_input_ids` manually. To facilitate in terms of input naming, + # we also allow the user to pass it under `input_ids`, if the encoder does not use it as the main input. + if model_kwargs is not None and "decoder_input_ids" in model_kwargs: + decoder_input_ids = model_kwargs.pop("decoder_input_ids") + elif "input_ids" in model_kwargs and model_input_name != "input_ids": + decoder_input_ids = model_kwargs.pop("input_ids") + else: + decoder_input_ids = None + + # 2. Encoder-decoder models expect the `decoder_input_ids` to start with a special token. Let's ensure that. 
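+        # for Whisper checkpoints this resolves to the <|startoftranscript|> token id,
+        # read from generation_config when neither argument is passed explicitly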
+ decoder_start_token_id = self._get_decoder_start_token_id(decoder_start_token_id, bos_token_id) + if device is None: + device = self.device + decoder_input_ids_start = torch.ones((batch_size, 1), dtype=torch.long, device="cpu") * decoder_start_token_id + + # no user input -> use decoder_start_token_id as decoder_input_ids + if decoder_input_ids is None: + decoder_input_ids = decoder_input_ids_start + # exception: Donut checkpoints have task-specific decoder starts and don't expect a BOS token + elif self.config.model_type == "vision-encoder-decoder" and "donut" in self.name_or_path.lower(): + pass + # user input but doesn't start with decoder_start_token_id -> prepend decoder_start_token_id (and adjust + # decoder_attention_mask if provided) + elif (decoder_input_ids[:, 0] != decoder_start_token_id).all().item(): + decoder_input_ids = torch.cat([decoder_input_ids_start, decoder_input_ids], dim=-1) + if "decoder_attention_mask" in model_kwargs: + decoder_attention_mask = model_kwargs["decoder_attention_mask"] + decoder_attention_mask = torch.cat( + (torch.ones_like(decoder_attention_mask)[:, :1], decoder_attention_mask), + dim=-1, + ) + model_kwargs["decoder_attention_mask"] = decoder_attention_mask.to(self.device) + + return decoder_input_ids.to("npu"), model_kwargs + + def _check_save_path(self, save_path, batch_size): + file_list = os.listdir(save_path) + expected_files = [file + f"{batch_size}.ts" for file in self.file_prefix_names] + for file in expected_files: + if file not in file_list: + raise ValueError(f"Expected file name is {file}, but can't be found in path: {save_path}") + + def _init_past_key_values(self): + init_seq_len = 0 + layer_nums = 32 + head_num = 20 + head_size = 64 + + self.init_past_key_values = [torch.randn([self.batch_size, head_num, init_seq_len, head_size]).to(torch.float16).to("npu"), + torch.randn([self.batch_size, head_num, init_seq_len, head_size]).to(torch.float16).to("npu"), + torch.randn([self.batch_size, head_num, init_seq_len, head_size]).to(torch.float16).to("npu"), + torch.randn([self.batch_size, head_num, init_seq_len, head_size]).to(torch.float16).to("npu")] * layer_nums + print("init past key value sucess.") + + def _get_init_past_key_values(self): + return self.init_past_key_values + + def prepare_inputs_for_generation( + self, + decoder_input_ids, + past_key_values=None, + encoder_outputs=None, + **kwargs + ): + if past_key_values is not None: + past_key_values_shape = past_key_values[0].shape + past_length = past_key_values_shape[-2] + decoder_input_ids_shape = decoder_input_ids.shape + # Some generation methods already pass only the last input ID + if decoder_input_ids_shape[1] > past_length: + remove_prefix_length = past_length + else: + # Default to old behavior: keep only final ID + remove_prefix_length = decoder_input_ids_shape[1] - 1 + + decoder_input_ids = decoder_input_ids[:, remove_prefix_length:] + + return { + "encoder_outputs": encoder_outputs, + "past_key_values": past_key_values, + "decoder_input_ids": decoder_input_ids + } + + +class AishellDataset(Dataset): + def __init__(self, audio_dir=None, trans_file = None, sampling_rate = None, *args, **kwargs): + super(AishellDataset, self).__init__() + self.root_dir = audio_dir + self.dataset_info_list = [] + self.label_file = trans_file + self.data_label_map = {} + self.sampling_rate = sampling_rate + self.DURATION = 49762.8 + self.audio_2_trans() + + + def get_total_duration(self): + return self.DURATION + + def data_2_label(self): + with open(self.label_file, 'r', encoding='utf-8') 
as file: + for line in file: + name = line.strip().split(" ")[0] + trans = "".join(x for x in line.split(" ")[1:-1]) + self.data_label_map[name] = trans + + def audio_2_trans(self): + self.data_2_label() + for root, _, files in os.walk(self.root_dir): + for filename in files: + filepath = os.path.join(root, filename) + trans = self.data_label_map[filename.split(".")[0]] + self.dataset_info_list.append((filepath, trans)) + + def __len__(self): + return len(self.dataset_info_list) + + def __getitem__(self, idx): + audio_list = [] + transcript_list = [] + + for audio_path, trans in self.dataset_info_list[idx]: + audio, _ = librosa.load(audio_path, sr=self.sampling_rate) + audio_list.append(audio) + transcript_list.append(trans) + return audio_list, transcript_list + + +def DataLoader(dataset:AishellDataset, batch_size=1, drop_last=True): + total_length = len(dataset) + num_batch = total_length // batch_size + start_index = 0 + for _ in range(num_batch): + end_index = min(start_index + batch_size, total_length) + yield dataset[start_index:end_index] + start_index = end_index + if drop_last and num_batch * batch_size < total_length: + return + +class MindiePipeline: + + def __init__(self, model:MindieWhisperForConditionalGeneration, device, save_path, batch_size, *args, **kwargs): + self.model = model + self.device = device + if not isinstance(model, MindieWhisperForConditionalGeneration): + raise ValueError(f"Please provide MindieWhisperForConditionalGeneration, found {type(model)}") + + if not (save_path and batch_size): + raise ValueError(f"Please provide compiled model save path and batch_size.") + + self.model.load_mindie_models(save_path, batch_size) + + def __call__(self, input_features, generate_kwargs={}): + inputs = input_features.to(self.device) + generate_kwargs["encoder_outputs"] = self.model.mindie_encoder(inputs)[0] + tokens = self.model.generate(attention_mask=None, **generate_kwargs) + return tokens diff --git a/MindIE/MindIE-Torch/built-in/audio/whisper_large_v3/readme.md b/MindIE/MindIE-Torch/built-in/audio/whisper_large_v3/readme.md new file mode 100644 index 0000000000..2e6693e475 --- /dev/null +++ b/MindIE/MindIE-Torch/built-in/audio/whisper_large_v3/readme.md @@ -0,0 +1,134 @@ +# Whisper-large-v3模型-推理指导 + +- [概述](#概述) +- [推理环境准备](#推理环境准备) +- [快速上手](#快速上手) + - [获取源码](#获取源码) + - [模型推理](#模型推理) + +# 概述 + +该工程使用mindietorch部署whisper-large-v3模型 + + +# 推理环境准备 + +- 该模型需要以下插件与驱动 + + **表 1** 版本配套表 + + | 配套 | 版本 | 环境准备指导 | + | ------ | ------- | ------------ | + | Python | 3.10.13 | - | + | torch | 2.1.0+cpu | - | + | torch_audio | 2.1.0+cpu | - | + | CANN | 8.0.RC2 | - | + | MindIE | 1.0.RC2.B091 | - | + +# 快速上手 +## 获取源码 + +1. 安装mindie包 + + ```bash + # 安装mindie + chmod +x ./Ascend-mindie_xxx.run + ./Ascend-mindie_xxx.run --install + source /usr/local/Ascend/mindie/set_env.sh + ``` + + +2. 模型权重下载路径: + ```bash + https://huggingface.co/openai/whisper-large-v3/tree/main + ``` + 将权重文件存放至当前目录下的mode_path文件夹,请先创建改文件夹。 + + +5. 安装依赖 + ``` + pip3 install torch==2.1.0 --index-url https://download.pytorch.org/whl/cpu + pip3 install nltk + pip3 install librosa + pip3 install transformers==4.36.0 + pip3 install numpy==1.26.0 + ``` + +## 模型推理 +1. 设置mindie内存池上限为16,执行如下命令设置环境变量 + ``` + export TORCH_AIE_NPU_CACHE_MAX_SIZE=16 + export ASCENDIE_FASTER_MODE=1 + ``` + +2. 模型编译和推理 + 执行下述命令进行模型编译和推理,sample.py脚本配套使用的数据集链接https://openslr.org/33/,下载得到data_aishell.tgz。如需使用其他的数据集,请修改对应的数据集解析方式。 + ```bash + tar -xzf data_aishell.tgz + cd ./data_aishell/wav + find . 
-type f -name "*.tar.gz" -exec tar -xzf {} \;
+    ```
+    解压目录结构
+    ```
+    data_aishell
+    |---transcript
+    |   |--aishell_transcript_v0.8.txt   # 标签数据
+    |---wav
+    |   |--S0*.tar.gz
+    |   |--test    # 测试语料,约10小时
+    |   |--dev
+    |   |--train
+    ```
+
+    ```bash
+    python3 sample.py \
+        -model_path ./model_path \
+        -audio_dir ./data_aishell/wav/test \
+        -trans_file ./data_aishell/transcript/aishell_transcript_v0.8.txt \
+        -bs 8 \
+        -save_path ./compiled_models \
+        -device_id 0 \
+        -soc_version Ascend310P3
+    ```
+
+    参数说明:
+    - -model_path:预训练模型路径,必选。
+    - -audio_dir:音频数据集文件夹路径,必选。
+    - -trans_file:音频对应的标签文件,必选。
+    - -bs:batch_size,默认值为8,可选。
+    - -save_path:编译好的模型的保存路径,必选。
+    - -device_id:选择模型运行的卡编号,默认值0,可选。
+    - -soc_version:芯片类型,默认值Ascend310P3,仅支持配置Ascend310P3或者Ascend910B4。
+
+    约束说明:
+    1. 当前暂不支持动态batch,batch_size改变后,需要重新编图。
+    2. 310P3上性能最佳的batch size为8;batch_size变大时,可能出现显存不足的情形。
+    3. 当前只支持推理短音频(单条语音小于30s)。
+    4. 仅支持使用MindiePipeline拉起推理流程。
+
+    评估指标:
+    1. 性能指标:bs=8(当前最优,待持续优化)时,推理aishell数据集需要5200+秒,转录比(语料总时长/推理总时长)约6.8倍。
+    2. 精度指标:采用错词率作为评估指标,Aishell数据集上的错词率为10.06%
+
+
+## 模型精度基线
+1. 使用torch-npu运行原始模型,得到精度极限,需要安装torch-npu。
+   ```
+   pip3 install torch-npu==2.1.0
+   ```
+2. 执行test_tools脚本。(该脚本不作为交付件,仅供开发和自测对齐精度使用)
+   ```bash
+   python3 test_tools.py \
+       -model_path ./model_path \
+       -audio_dir ./data_aishell/wav/test \
+       -trans_file ./data_aishell/transcript/aishell_transcript_v0.8.txt \
+       -bs 1 \
+       -device_id 0
+   ```
+3. 说明:
+   1. 执行脚本耗时5小时左右。
+   2. 错词率为10.58%
\ No newline at end of file
diff --git a/MindIE/MindIE-Torch/built-in/audio/whisper_large_v3/sample.py b/MindIE/MindIE-Torch/built-in/audio/whisper_large_v3/sample.py
new file mode 100644
index 0000000000..06ebb8af24
--- /dev/null
+++ b/MindIE/MindIE-Torch/built-in/audio/whisper_large_v3/sample.py
@@ -0,0 +1,53 @@
+import argparse
+import time
+from transformers import WhisperProcessor
+from nltk.metrics.distance import edit_distance
+
+import mindietorch
+from mindie_whisper import MindieWhisperForConditionalGeneration, MindiePipeline, AishellDataset, DataLoader
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument('-model_path', type=str, required=True, help="please provide model path.")
+    parser.add_argument('-audio_dir', type=str, required=True, help="please provide speech samples dir.")
+    parser.add_argument('-trans_file', type=str, required=True, help="please provide label file.")
+    parser.add_argument('-bs', type=int, default=8, help="please provide batch_size, default:8.")
+    parser.add_argument('-soc_version', type=str, choices=["Ascend310P3", "Ascend910B4"], default="Ascend310P3",
+                        help="please provide soc_version, default:Ascend310P3.")
+    parser.add_argument('-save_path', type=str, default="compiled_models", help="compiled models save dir.")
+    parser.add_argument('-device_id', type=int, default=0)
+
+    args = parser.parse_args()
+    device = f"npu:{args.device_id}"
+    mindie_whisper = MindieWhisperForConditionalGeneration.from_pretrained(args.model_path)
+    mindie_whisper.compile(args)
+    mindietorch.set_device(args.device_id)
+
+    processor = WhisperProcessor.from_pretrained(args.model_path)
+    mindie_pipe = MindiePipeline(mindie_whisper, device, args.save_path, args.bs)
+
+    sampling_rate = processor.feature_extractor.sampling_rate
+    aishell_dataset = AishellDataset(args.audio_dir, args.trans_file, sampling_rate)
+    duration = aishell_dataset.get_total_duration()
+    print(f"speech sample num is {len(aishell_dataset)}, duration (s) is {duration}")
+    data_loader = DataLoader(aishell_dataset, batch_size=args.bs)
+    wer = 0
+    step = 0
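+    # first pass: accuracy, scored as per-sample character-level edit distance;
+    # the second pass below re-runs inference purely to time end-to-end throughput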
+    for inp, labels in data_loader:
+        input_features = processor(inp, sampling_rate=sampling_rate, return_tensors="pt").input_features
+        predicted_ids = mindie_pipe(input_features)
+        transcription = processor.batch_decode(predicted_ids, skip_special_tokens=True)
+        for predict, label in zip(transcription, labels):
+            predict = [w for w in predict]
+            label = [w for w in label]
+            wer += edit_distance(predict, label) / len(label)
+            step += 1
+        print(f"predict {transcription}, labels {labels}")
+    # DataLoader is a generator and is exhausted by the loop above,
+    # so rebuild it before the timing pass
+    data_loader = DataLoader(aishell_dataset, batch_size=args.bs)
+    t1 = time.time()
+    for inp, labels in data_loader:
+        input_features = processor(inp, sampling_rate=sampling_rate, return_tensors="pt").input_features
+        predicted_ids = mindie_pipe(input_features)
+        transcription = processor.batch_decode(predicted_ids, skip_special_tokens=True)
+    t2 = time.time()
+    print(f"word-error-rate {wer/step}, performance {duration / (t2 - t1)}")
\ No newline at end of file
diff --git a/MindIE/MindIE-Torch/built-in/audio/whisper_large_v3/test_tools.py b/MindIE/MindIE-Torch/built-in/audio/whisper_large_v3/test_tools.py
new file mode 100644
index 0000000000..ae1f8e6755
--- /dev/null
+++ b/MindIE/MindIE-Torch/built-in/audio/whisper_large_v3/test_tools.py
@@ -0,0 +1,46 @@
+import argparse
+import time
+from typing import Optional, Tuple, Union, List, Dict
+from collections import OrderedDict
+
+import torch
+import torch_npu
+from nltk.metrics.distance import edit_distance
+from transformers import WhisperProcessor, WhisperForConditionalGeneration
+from mindie_whisper import AishellDataset, DataLoader
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument('-model_path', type=str, required=True, help="please provide model path.")
+    parser.add_argument('-audio_dir', type=str, required=True, help="please provide speech samples dir.")
+    parser.add_argument('-trans_file', type=str, required=True, help="please provide label file.")
+    parser.add_argument('-bs', type=int, default=8, help="please provide batch_size, default:8.")
+    parser.add_argument('-soc_version', type=str, choices=["Ascend310P3", "Ascend910B4"], default="Ascend310P3",
+                        help="please provide soc_version, default:Ascend310P3.")
+    parser.add_argument('-device_id', type=int, default=0)
+    args = parser.parse_args()
+    device = f"npu:{args.device_id}"
+    torch_npu.npu.set_device(args.device_id)
+    whisper = WhisperForConditionalGeneration.from_pretrained(args.model_path).to(device)
+    processor = WhisperProcessor.from_pretrained(args.model_path)
+
+    sampling_rate = processor.feature_extractor.sampling_rate
+    aishell_dataset = AishellDataset(args.audio_dir, args.trans_file, sampling_rate)
+    duration = aishell_dataset.get_total_duration()
+    print(f"speech sample num is {len(aishell_dataset)}, duration (s) is {duration}")
+    data_loader = DataLoader(aishell_dataset, batch_size=args.bs)
+    wer = 0
+    step = 0
+    start_time = time.time()
+    for inp, labels in data_loader:
+        input_features = processor(inp, sampling_rate=sampling_rate, return_tensors="pt").input_features
+        predicted_ids = whisper.generate(input_features.to(device))
+        transcription = processor.batch_decode(predicted_ids, skip_special_tokens=True)
+        for predict, label in zip(transcription, labels):
+            predict = [w for w in predict]
+            label = [w for w in label]
+            wer += edit_distance(predict, label) / len(label)
+            step += 1
+        print(f"predict {transcription}, labels {labels}")
+    # measure elapsed wall-clock time once after the loop, rather than
+    # summing absolute timestamps inside it
+    time_cost = time.time() - start_time
+    print(f"torch_npu: word-error-rate {wer/step}, performance {duration / time_cost}")
-- 
Gitee

From 358dbd8ff42e96f02b4ffd4a51a0d94503d6feb0 Mon Sep 17
00:00:00 2001 From: tiansongxue Date: Mon, 9 Sep 2024 18:01:47 +0800 Subject: [PATCH 02/44] fix call _init_past_key_values --- .../built-in/audio/whisper_large_v3/mindie_whisper.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/MindIE/MindIE-Torch/built-in/audio/whisper_large_v3/mindie_whisper.py b/MindIE/MindIE-Torch/built-in/audio/whisper_large_v3/mindie_whisper.py index e0aeb3634b..6f58086458 100644 --- a/MindIE/MindIE-Torch/built-in/audio/whisper_large_v3/mindie_whisper.py +++ b/MindIE/MindIE-Torch/built-in/audio/whisper_large_v3/mindie_whisper.py @@ -27,8 +27,7 @@ class MindieWhisperDecoderLayer(WhisperDecoderLayer): def __init__(self, config): super().__init__(config) self.embed_dim = config.d_model - self._init_past_key_values() - + def forward( self, hidden_states, @@ -204,12 +203,10 @@ class MindieWhisperForConditionalGeneration(WhisperForConditionalGeneration, Gen self.batch_size = -1 self.encoder_seq_len = 1500 self.init_encoder_seq_len = self.encoder_seq_len - 1 - - self.file_prefix_names = ["mindie_whisper_encoder_bs", "mindie_whisper_decoder_bs" ] - + self._init_past_key_values() def compile(self, args): print("Start compiling Mindie-Whisper, it will take some time, please wait.") -- Gitee From b0566b4ae86a19ce3d64e69eeffea6bbb274c160 Mon Sep 17 00:00:00 2001 From: tiansongxue Date: Mon, 9 Sep 2024 18:05:12 +0800 Subject: [PATCH 03/44] fix _init_past_key_values --- .../audio/whisper_large_v3/mindie_whisper.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/MindIE/MindIE-Torch/built-in/audio/whisper_large_v3/mindie_whisper.py b/MindIE/MindIE-Torch/built-in/audio/whisper_large_v3/mindie_whisper.py index 6f58086458..36f65b76c3 100644 --- a/MindIE/MindIE-Torch/built-in/audio/whisper_large_v3/mindie_whisper.py +++ b/MindIE/MindIE-Torch/built-in/audio/whisper_large_v3/mindie_whisper.py @@ -206,7 +206,6 @@ class MindieWhisperForConditionalGeneration(WhisperForConditionalGeneration, Gen self.file_prefix_names = ["mindie_whisper_encoder_bs", "mindie_whisper_decoder_bs" ] - self._init_past_key_values() def compile(self, args): print("Start compiling Mindie-Whisper, it will take some time, please wait.") @@ -333,6 +332,7 @@ class MindieWhisperForConditionalGeneration(WhisperForConditionalGeneration, Gen but found save_path is {save_path}, batch_size is{batch_size}.") self._check_save_path(save_path, batch_size) self.batch_size = batch_size + self._init_past_key_values(self.batch_size) if not self.has_load: self.mindie_encoder = torch.jit.load(f"{save_path}/{self.file_prefix_names[0]}{batch_size}.ts") @@ -575,16 +575,16 @@ class MindieWhisperForConditionalGeneration(WhisperForConditionalGeneration, Gen if file not in file_list: raise ValueError(f"Expected file name is {file}, but can't be found in path: {save_path}") - def _init_past_key_values(self): + def _init_past_key_values(self, batch_size): init_seq_len = 0 layer_nums = 32 head_num = 20 head_size = 64 - self.init_past_key_values = [torch.randn([self.batch_size, head_num, init_seq_len, head_size]).to(torch.float16).to("npu"), - torch.randn([self.batch_size, head_num, init_seq_len, head_size]).to(torch.float16).to("npu"), - torch.randn([self.batch_size, head_num, init_seq_len, head_size]).to(torch.float16).to("npu"), - torch.randn([self.batch_size, head_num, init_seq_len, head_size]).to(torch.float16).to("npu")] * layer_nums + self.init_past_key_values = [torch.randn([batch_size, head_num, init_seq_len, head_size]).to(torch.float16).to("npu"), + torch.randn([batch_size, 
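+                                     # init_seq_len == 0 gives an empty KV cache for the first decode step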
head_num, init_seq_len, head_size]).to(torch.float16).to("npu"), + torch.randn([batch_size, head_num, init_seq_len, head_size]).to(torch.float16).to("npu"), + torch.randn([batch_size, head_num, init_seq_len, head_size]).to(torch.float16).to("npu")] * layer_nums print("init past key value sucess.") def _get_init_past_key_values(self): -- Gitee From a5a1538b05cfa7235bdae12414965c27044d2e97 Mon Sep 17 00:00:00 2001 From: tiansongxue Date: Tue, 10 Sep 2024 09:27:04 +0800 Subject: [PATCH 04/44] fix compile decoder compile info, needed float16 --- .../audio/whisper_large_v3/mindie_whisper.py | 32 +++++++++++-------- .../built-in/audio/whisper_large_v3/readme.md | 10 ++++++ 2 files changed, 29 insertions(+), 13 deletions(-) diff --git a/MindIE/MindIE-Torch/built-in/audio/whisper_large_v3/mindie_whisper.py b/MindIE/MindIE-Torch/built-in/audio/whisper_large_v3/mindie_whisper.py index 36f65b76c3..f64ee48e09 100644 --- a/MindIE/MindIE-Torch/built-in/audio/whisper_large_v3/mindie_whisper.py +++ b/MindIE/MindIE-Torch/built-in/audio/whisper_large_v3/mindie_whisper.py @@ -279,37 +279,43 @@ class MindieWhisperForConditionalGeneration(WhisperForConditionalGeneration, Gen layer_nums = 32 max_len = 448 - encoder_outputs = torch.randn([args.bs, self.encoder_seq_len, hidden_size], dtype=torch.float16) + encoder_outputs = torch.randn([args.bs, self.encoder_seq_len, hidden_size]) decoder_input_ids = torch.randint(1, 4, (args.bs, 1)) - all_past_key_value = [torch.randn([args.bs, head_num, seq_len, head_size], dtype=torch.float16), - torch.randn([args.bs, head_num, seq_len, head_size], dtype=torch.float16), - torch.randn([args.bs, head_num, self.encoder_seq_len, head_size], dtype=torch.float16), - torch.randn([args.bs, head_num, self.encoder_seq_len, head_size], dtype=torch.float16)] * layer_nums + all_past_key_value = [torch.randn([args.bs, head_num, seq_len, head_size]), + torch.randn([args.bs, head_num, seq_len, head_size]), + torch.randn([args.bs, head_num, self.encoder_seq_len, head_size]), + torch.randn([args.bs, head_num, self.encoder_seq_len, head_size])] * layer_nums traced_args = [decoder_input_ids, encoder_outputs] traced_args.extend(all_past_key_value) traced_decoder = torch.jit.trace(model.eval(), example_inputs=traced_args) key_value_infos = [mindietorch.Input( min_shape=(args.bs, head_num, 0, head_size), - max_shape=(args.bs, head_num, max_len, head_size) + max_shape=(args.bs, head_num, max_len, head_size), + dtype=dtype.FLOAT16 ), mindietorch.Input( min_shape=(args.bs, head_num, seq_len, head_size), - max_shape=(args.bs, head_num, max_len, head_size) + max_shape=(args.bs, head_num, max_len, head_size), + dtype=dtype.FLOAT16 ), mindietorch.Input( min_shape=(args.bs, head_num, self.init_encoder_seq_len, head_size), - max_shape=(args.bs, head_num, self.encoder_seq_len, head_size) + max_shape=(args.bs, head_num, self.encoder_seq_len, head_size), + dtype=dtype.FLOAT16 ), mindietorch.Input( min_shape=(args.bs, head_num, self.init_encoder_seq_len, head_size), - max_shape=(args.bs, head_num, self.encoder_seq_len, head_size) + max_shape=(args.bs, head_num, self.encoder_seq_len, head_size), + dtype=dtype.FLOAT16 ) ] * layer_nums input_info = [mindietorch.Input(min_shape=(args.bs, 1), max_shape=(args.bs, 1), dtype=dtype.INT64), mindietorch.Input(min_shape=(args.bs, self.encoder_seq_len), max_shape=(args.bs, self.encoder_seq_len))] input_info.extend(key_value_infos) - buffer_size = math.ceil((args.bs * self.encoder_seq_len * hidden_size * 4) / 1024 / 1024) + float_size = 4 + # pre set malloc each output buffer 
size/MB + buffer_size = math.ceil((args.bs * self.encoder_seq_len * hidden_size * float_size) / 1024 / 1024) print(f"Set {buffer_size}/MB for each output.") compiled_decoder = mindietorch.compile(traced_decoder, inputs=input_info, @@ -317,10 +323,10 @@ class MindieWhisperForConditionalGeneration(WhisperForConditionalGeneration, Gen soc_version=args.soc_version, default_buffer_size_vec=[buffer_size]) - save_file = os.path.join(save_path, f"{self.file_prefix_names[2]}{args.bs}.ts") + save_file = os.path.join(save_path, f"{self.file_prefix_names[1]}{args.bs}.ts") torch.jit.save(compiled_decoder, save_file) - traced_file = os.path.join(save_path, f"{self.file_prefix_names[2]}{args.bs}.pt") + traced_file = os.path.join(save_path, f"{self.file_prefix_names[1]}{args.bs}.pt") torch.jit.save(traced_decoder, traced_file) print(f"Compile whisper_decoder success, saved in {save_file}.") @@ -341,7 +347,7 @@ class MindieWhisperForConditionalGeneration(WhisperForConditionalGeneration, Gen # self.mindie_decoder_prefill = torch.jit.load(f"{save_path}/{self.file_prefix_names[1]}{batch_size}.ts") # print(f"load {self.file_prefix_names[1]}{batch_size}.ts success.") - self.mindie_decoder = torch.jit.load(f"{save_path}/{self.file_prefix_names[2]}{batch_size}.ts") + self.mindie_decoder = torch.jit.load(f"{save_path}/{self.file_prefix_names[1]}{batch_size}.ts") print(f"load {self.file_prefix_names[2]}{batch_size}.ts success.") self.has_load = True else: diff --git a/MindIE/MindIE-Torch/built-in/audio/whisper_large_v3/readme.md b/MindIE/MindIE-Torch/built-in/audio/whisper_large_v3/readme.md index 2e6693e475..b1b14f6a49 100644 --- a/MindIE/MindIE-Torch/built-in/audio/whisper_large_v3/readme.md +++ b/MindIE/MindIE-Torch/built-in/audio/whisper_large_v3/readme.md @@ -54,6 +54,15 @@ pip3 install numpy==1.26.0 ``` +6. 补丁(临时,后续会修改) + ``` + cd 到transformers的安装路径,找到modeling_whisper.py文件 + 将408行的 past_key_value = (key_states, value_states) + 修改为: past_key_value = (key_states.to(torch.float16), value_states.to(torch.float16)) + ``` + + + ## 模型推理 1. 设置mindie内存池上限为16,执行如下命令设置环境变量 ``` @@ -115,6 +124,7 @@ 2. 精度指标:采用错词率作为评估指标,Aishell数据集上的错词率为10.06% + ## 模型精度基线 1. 
使用torch—npu运行原始模型,得到精度极限,需要安装torch-npu。 ``` -- Gitee From 5dc832ed448043d2373633796353255471116858 Mon Sep 17 00:00:00 2001 From: tiansongxue Date: Tue, 10 Sep 2024 10:01:11 +0800 Subject: [PATCH 05/44] add mindie PFA --- .../audio/whisper_large_v3/mindie_whisper.py | 134 ++++++++++++++++-- 1 file changed, 119 insertions(+), 15 deletions(-) diff --git a/MindIE/MindIE-Torch/built-in/audio/whisper_large_v3/mindie_whisper.py b/MindIE/MindIE-Torch/built-in/audio/whisper_large_v3/mindie_whisper.py index f64ee48e09..cfc5c21e46 100644 --- a/MindIE/MindIE-Torch/built-in/audio/whisper_large_v3/mindie_whisper.py +++ b/MindIE/MindIE-Torch/built-in/audio/whisper_large_v3/mindie_whisper.py @@ -14,7 +14,7 @@ import torch.utils.checkpoint from torch import nn from transformers import WhisperForConditionalGeneration from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask -from transformers.models.whisper.modeling_whisper import WhisperDecoder,WhisperModel,WhisperDecoderLayer +from transformers.models.whisper.modeling_whisper import WhisperDecoder,WhisperModel,WhisperDecoderLayer,WhisperAttention from transformers.generation.logits_process import LogitsProcessorList from transformers.generation.stopping_criteria import StoppingCriteriaList, validate_stopping_criteria from transformers.generation.utils import GenerationMixin @@ -22,12 +22,116 @@ import mindietorch from mindietorch._enums import dtype -class MindieWhisperDecoderLayer(WhisperDecoderLayer): +class MindieAttention(WhisperAttention): + def __init__( + self, + embed_dim: int, + num_heads: int, + dropout: float = 0.0, + is_decoder: bool = False, + bias: bool = True, + is_causal: bool = False, + config = None + ): + super().__init__( + embed_dim, + num_heads, + dropout, + is_decoder, + bias, + is_causal, + config + ) + def forward( + self, + hidden_states: torch.Tensor, + key_value_states: Optional[torch.Tensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + attention_mask: Optional[torch.Tensor] = None, + layer_head_mask: Optional[torch.Tensor] = None, + output_attentions: bool = False, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + + is_cross_attention = key_value_states is not None + bsz, tgt_len, _ = hidden_states.size() + query_states = self.q_proj(hidden_states) + + if ( + is_cross_attention + and past_key_value is not None + and past_key_value[0].shape[2] == key_value_states.shape[1] + ): + # reuse k,v, cross_attentions + key_states = past_key_value[0] + value_states = past_key_value[1] + elif is_cross_attention: + # cross_attentions + key_states = self._shape(self.k_proj(key_value_states), -1, bsz) + value_states = self._shape(self.v_proj(key_value_states), -1, bsz) + elif past_key_value is not None: + # reuse k, v, self_attention + key_states = self._shape(self.k_proj(hidden_states), -1, bsz) + value_states = self._shape(self.v_proj(hidden_states), -1, bsz) + key_states = torch.cat([past_key_value[0], key_states], dim=2) + value_states = torch.cat([past_key_value[1], value_states], dim=2) + else: + # self_attention + key_states = self._shape(self.k_proj(hidden_states), -1, bsz) + value_states = self._shape(self.v_proj(hidden_states), -1, bsz) + + if self.is_decoder: + past_key_value = (key_states.to(torch.float16), value_states.to()) + + # B S N D + query_states = self._shape(query_states, tgt_len, bsz) + + key_states = key_states.transpose(1, 2).contiguous() + value_states = value_states.transpose(1, 2).contiguous() + + attn_output = 
torch.ops.aie.flash_attention( + query = query_states, + key = key_states, + value = value_states, + num_head = self.num_heads, + scale = self.scaleing, + layout = "BSND", + type = "PFA" + ) + + # B S N D + attn_output = attn_output.view(bsz, tgt_len, self.num_heads, self.head_dim) + + attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim) + + attn_output = self.out_proj(attn_output) + + return attn_output, _, past_key_value + + +class MindieWhisperDecoderLayer(WhisperDecoderLayer): + def __init__(self, config): super().__init__(config) self.embed_dim = config.d_model + self.self_attn = MindieAttention( + embed_dim=self.embed_dim, + num_heads=config.decoder_attention_heads, + dropout=config.attention_dropout, + is_decoder=True, + is_causal=True, + config=config + ) + + self.encoder_attn = MindieAttention( + self.embed_dim, + config.decoder_attention_heads, + dropout=config.attention_dropout, + is_decoder=True, + config=config, + ) + def forward( self, hidden_states, @@ -149,7 +253,6 @@ class MindieWhisperDecoder(WhisperDecoder): return hidden_states, past_key_value_cache - class MindieWhisperModel(WhisperModel): def __init__(self, config): @@ -295,7 +398,7 @@ class MindieWhisperForConditionalGeneration(WhisperForConditionalGeneration, Gen dtype=dtype.FLOAT16 ), mindietorch.Input( - min_shape=(args.bs, head_num, seq_len, head_size), + min_shape=(args.bs, head_num, 0, head_size), max_shape=(args.bs, head_num, max_len, head_size), dtype=dtype.FLOAT16 ), @@ -667,17 +770,6 @@ class AishellDataset(Dataset): return audio_list, transcript_list -def DataLoader(dataset:AishellDataset, batch_size=1, drop_last=True): - total_length = len(dataset) - num_batch = total_length // batch_size - start_index = 0 - for _ in range(num_batch): - end_index = min(start_index + batch_size, total_length) - yield dataset[start_index:end_index] - start_index = end_index - if drop_last and num_batch * batch_size < total_length: - return - class MindiePipeline: def __init__(self, model:MindieWhisperForConditionalGeneration, device, save_path, batch_size, *args, **kwargs): @@ -696,3 +788,15 @@ class MindiePipeline: generate_kwargs["encoder_outputs"] = self.model.mindie_encoder(inputs)[0] tokens = self.model.generate(attention_mask=None, **generate_kwargs) return tokens + + +def DataLoader(dataset:AishellDataset, batch_size=1, drop_last=True): + total_length = len(dataset) + num_batch = total_length // batch_size + start_index = 0 + for _ in range(num_batch): + end_index = min(start_index + batch_size, total_length) + yield dataset[start_index:end_index] + start_index = end_index + if drop_last and num_batch * batch_size < total_length: + return \ No newline at end of file -- Gitee From 46514b668b9b1894a8e01c0ac348ee743f296b5a Mon Sep 17 00:00:00 2001 From: tiansongxue Date: Tue, 10 Sep 2024 10:02:53 +0800 Subject: [PATCH 06/44] add mindie PFA --- MindIE/MindIE-Torch/built-in/audio/whisper_large_v3/readme.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/MindIE/MindIE-Torch/built-in/audio/whisper_large_v3/readme.md b/MindIE/MindIE-Torch/built-in/audio/whisper_large_v3/readme.md index b1b14f6a49..6016feb961 100644 --- a/MindIE/MindIE-Torch/built-in/audio/whisper_large_v3/readme.md +++ b/MindIE/MindIE-Torch/built-in/audio/whisper_large_v3/readme.md @@ -54,7 +54,7 @@ pip3 install numpy==1.26.0 ``` -6. 补丁(临时,后续会修改) +6. 
补丁(临时,后续接入PFA之后会修改) ``` cd 到transformers的安装路径,找到modeling_whisper.py文件 将408行的 past_key_value = (key_states, value_states) -- Gitee From 11e2ace882b282023f5162929f2ef9292c9a2950 Mon Sep 17 00:00:00 2001 From: tiansongxue Date: Tue, 10 Sep 2024 10:15:47 +0800 Subject: [PATCH 07/44] fix init past key values shape --- .../built-in/audio/whisper_large_v3/mindie_whisper.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/MindIE/MindIE-Torch/built-in/audio/whisper_large_v3/mindie_whisper.py b/MindIE/MindIE-Torch/built-in/audio/whisper_large_v3/mindie_whisper.py index cfc5c21e46..5391d79cd6 100644 --- a/MindIE/MindIE-Torch/built-in/audio/whisper_large_v3/mindie_whisper.py +++ b/MindIE/MindIE-Torch/built-in/audio/whisper_large_v3/mindie_whisper.py @@ -692,8 +692,8 @@ class MindieWhisperForConditionalGeneration(WhisperForConditionalGeneration, Gen self.init_past_key_values = [torch.randn([batch_size, head_num, init_seq_len, head_size]).to(torch.float16).to("npu"), torch.randn([batch_size, head_num, init_seq_len, head_size]).to(torch.float16).to("npu"), - torch.randn([batch_size, head_num, init_seq_len, head_size]).to(torch.float16).to("npu"), - torch.randn([batch_size, head_num, init_seq_len, head_size]).to(torch.float16).to("npu")] * layer_nums + torch.randn([batch_size, head_num, self.init_encoder_seq_len, head_size]).to(torch.float16).to("npu"), + torch.randn([batch_size, head_num, self.init_encoder_seq_len, head_size]).to(torch.float16).to("npu")] * layer_nums print("init past key value sucess.") def _get_init_past_key_values(self): -- Gitee From 49fe0fb2375f7611ed3e729d6bc1f88e68224fab Mon Sep 17 00:00:00 2001 From: tiansongxue Date: Tue, 10 Sep 2024 11:34:50 +0800 Subject: [PATCH 08/44] rollback compile prefill decoder --- .../audio/whisper_large_v3/mindie_whisper.py | 58 +++++++++---------- 1 file changed, 28 insertions(+), 30 deletions(-) diff --git a/MindIE/MindIE-Torch/built-in/audio/whisper_large_v3/mindie_whisper.py b/MindIE/MindIE-Torch/built-in/audio/whisper_large_v3/mindie_whisper.py index 5391d79cd6..f144e3b9ad 100644 --- a/MindIE/MindIE-Torch/built-in/audio/whisper_large_v3/mindie_whisper.py +++ b/MindIE/MindIE-Torch/built-in/audio/whisper_large_v3/mindie_whisper.py @@ -47,10 +47,7 @@ class MindieAttention(WhisperAttention): self, hidden_states: torch.Tensor, key_value_states: Optional[torch.Tensor] = None, - past_key_value: Optional[Tuple[torch.Tensor]] = None, - attention_mask: Optional[torch.Tensor] = None, - layer_head_mask: Optional[torch.Tensor] = None, - output_attentions: bool = False, + past_key_value: Optional[Tuple[torch.Tensor]] = None ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: is_cross_attention = key_value_states is not None @@ -89,6 +86,7 @@ class MindieAttention(WhisperAttention): key_states = key_states.transpose(1, 2).contiguous() value_states = value_states.transpose(1, 2).contiguous() + # B S N D attn_output = torch.ops.aie.flash_attention( query = query_states, key = key_states, @@ -99,7 +97,6 @@ class MindieAttention(WhisperAttention): type = "PFA" ) - # B S N D attn_output = attn_output.view(bsz, tgt_len, self.num_heads, self.head_dim) attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim) @@ -307,6 +304,7 @@ class MindieWhisperForConditionalGeneration(WhisperForConditionalGeneration, Gen self.encoder_seq_len = 1500 self.init_encoder_seq_len = self.encoder_seq_len - 1 self.file_prefix_names = ["mindie_whisper_encoder_bs", + "mindie_decoder_prefill_bs", "mindie_whisper_decoder_bs" ] @@ -353,26 
+351,26 @@ class MindieWhisperForConditionalGeneration(WhisperForConditionalGeneration, Gen torch.jit.save(encoder_traced, traced_file) print(f"Compile encoder success, saved in {save_file}") - # print("Start compiling prefill_decoder.") - # encoder_seq_len = 1500 - # hidden_size = 1280 - # encoder_outputs = torch.randn([args.bs, encoder_seq_len, hidden_size]) - # decoder_input_ids = torch.randint(1, 4, (args.bs, 1)) - # prefill_decoder_traced = torch.jit.trace(model.eval(), (decoder_input_ids, encoder_outputs)) - # input_info = [mindietorch.Input(shape=(args.bs, 1), dtype=dtype.INT64), - # mindietorch.Input(shape=(args.bs, encoder_seq_len, hidden_size))] - # prefill_decoder_compiled = mindietorch.compile(prefill_decoder_traced, - # inputs=input_info, - # precision_policy=mindietorch.PrecisionPolicy.FP16, - # soc_version=args.soc_version) + print("Start compiling prefill_decoder.") + encoder_seq_len = 1500 + hidden_size = 1280 + encoder_outputs = torch.randn([args.bs, encoder_seq_len, hidden_size]) + decoder_input_ids = torch.randint(1, 4, (args.bs, 1)) + prefill_decoder_traced = torch.jit.trace(model.eval(), (decoder_input_ids, encoder_outputs)) + input_info = [mindietorch.Input(shape=(args.bs, 1), dtype=dtype.INT64), + mindietorch.Input(shape=(args.bs, encoder_seq_len, hidden_size))] + prefill_decoder_compiled = mindietorch.compile(prefill_decoder_traced, + inputs=input_info, + precision_policy=mindietorch.PrecisionPolicy.FP16, + soc_version=args.soc_version) - # save_file = os.path.join(save_path, f"{self.file_prefix_names[1]}{args.bs}.ts") - # torch.jit.save(prefill_decoder_compiled, save_file) + save_file = os.path.join(save_path, f"{self.file_prefix_names[1]}{args.bs}.ts") + torch.jit.save(prefill_decoder_compiled, save_file) - # traced_file = os.path.join(save_path, f"{self.file_prefix_names[1]}{args.bs}.pt") - # torch.jit.save(prefill_decoder_traced, traced_file) + traced_file = os.path.join(save_path, f"{self.file_prefix_names[1]}{args.bs}.pt") + torch.jit.save(prefill_decoder_traced, traced_file) - # print(f"Compile prefill_decoder success, saved in {save_file}.") + print(f"Compile prefill_decoder success, saved in {save_file}.") print("Start compiling decoder.") hidden_size = 1280 @@ -426,10 +424,10 @@ class MindieWhisperForConditionalGeneration(WhisperForConditionalGeneration, Gen soc_version=args.soc_version, default_buffer_size_vec=[buffer_size]) - save_file = os.path.join(save_path, f"{self.file_prefix_names[1]}{args.bs}.ts") + save_file = os.path.join(save_path, f"{self.file_prefix_names[2]}{args.bs}.ts") torch.jit.save(compiled_decoder, save_file) - traced_file = os.path.join(save_path, f"{self.file_prefix_names[1]}{args.bs}.pt") + traced_file = os.path.join(save_path, f"{self.file_prefix_names[2]}{args.bs}.pt") torch.jit.save(traced_decoder, traced_file) print(f"Compile whisper_decoder success, saved in {save_file}.") @@ -441,16 +439,16 @@ class MindieWhisperForConditionalGeneration(WhisperForConditionalGeneration, Gen but found save_path is {save_path}, batch_size is{batch_size}.") self._check_save_path(save_path, batch_size) self.batch_size = batch_size - self._init_past_key_values(self.batch_size) + # self._init_past_key_values(self.batch_size) if not self.has_load: self.mindie_encoder = torch.jit.load(f"{save_path}/{self.file_prefix_names[0]}{batch_size}.ts") print(f"load {self.file_prefix_names[0]}{batch_size}.ts success.") - # self.mindie_decoder_prefill = torch.jit.load(f"{save_path}/{self.file_prefix_names[1]}{batch_size}.ts") - # print(f"load 
{self.file_prefix_names[1]}{batch_size}.ts success.") + self.mindie_decoder_prefill = torch.jit.load(f"{save_path}/{self.file_prefix_names[1]}{batch_size}.ts") + print(f"load {self.file_prefix_names[1]}{batch_size}.ts success.") - self.mindie_decoder = torch.jit.load(f"{save_path}/{self.file_prefix_names[1]}{batch_size}.ts") + self.mindie_decoder = torch.jit.load(f"{save_path}/{self.file_prefix_names[2]}{batch_size}.ts") print(f"load {self.file_prefix_names[2]}{batch_size}.ts success.") self.has_load = True else: @@ -514,11 +512,11 @@ class MindieWhisperForConditionalGeneration(WhisperForConditionalGeneration, Gen model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs) args = [model_inputs["decoder_input_ids"].contiguous().to("npu"), model_inputs["encoder_outputs"]] if is_first_step: - args.extend(self._get_init_past_key_values()) + outputs = self.mindie_decoder_prefill(*args) is_first_step = False else: args.extend(model_inputs["past_key_values"]) - outputs = self.mindie_decoder(*args) + outputs = self.mindie_decoder(*args) if synced_gpus and this_peer_finished: continue -- Gitee From c0121947f42c92c307d12cb136fbe25024d1bc64 Mon Sep 17 00:00:00 2001 From: tiansongxue Date: Tue, 10 Sep 2024 11:47:10 +0800 Subject: [PATCH 09/44] rollback compile decoder_prefill --- .../built-in/audio/whisper_large_v3/mindie_whisper.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/MindIE/MindIE-Torch/built-in/audio/whisper_large_v3/mindie_whisper.py b/MindIE/MindIE-Torch/built-in/audio/whisper_large_v3/mindie_whisper.py index f144e3b9ad..0bf75aedb5 100644 --- a/MindIE/MindIE-Torch/built-in/audio/whisper_large_v3/mindie_whisper.py +++ b/MindIE/MindIE-Torch/built-in/audio/whisper_large_v3/mindie_whisper.py @@ -78,7 +78,7 @@ class MindieAttention(WhisperAttention): value_states = self._shape(self.v_proj(hidden_states), -1, bsz) if self.is_decoder: - past_key_value = (key_states.to(torch.float16), value_states.to()) + past_key_value = (key_states.to(torch.float16), value_states.to(torch.float16)) # B S N D query_states = self._shape(query_states, tgt_len, bsz) @@ -92,7 +92,7 @@ class MindieAttention(WhisperAttention): key = key_states, value = value_states, num_head = self.num_heads, - scale = self.scaleing, + scale = self.scaling, layout = "BSND", type = "PFA" ) -- Gitee From aa1139c06d233e6fcac4df515afe768d2ec40cd9 Mon Sep 17 00:00:00 2001 From: tiansongxue Date: Tue, 10 Sep 2024 14:23:43 +0800 Subject: [PATCH 10/44] add **kwargs param for PFA --- .../built-in/audio/whisper_large_v3/mindie_whisper.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/MindIE/MindIE-Torch/built-in/audio/whisper_large_v3/mindie_whisper.py b/MindIE/MindIE-Torch/built-in/audio/whisper_large_v3/mindie_whisper.py index 0bf75aedb5..565a4e6e95 100644 --- a/MindIE/MindIE-Torch/built-in/audio/whisper_large_v3/mindie_whisper.py +++ b/MindIE/MindIE-Torch/built-in/audio/whisper_large_v3/mindie_whisper.py @@ -47,7 +47,8 @@ class MindieAttention(WhisperAttention): self, hidden_states: torch.Tensor, key_value_states: Optional[torch.Tensor] = None, - past_key_value: Optional[Tuple[torch.Tensor]] = None + past_key_value: Optional[Tuple[torch.Tensor]] = None, + **kwargs ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: is_cross_attention = key_value_states is not None -- Gitee From f578bcef38bf8223c06fef6ba9428cc5dad7de2a Mon Sep 17 00:00:00 2001 From: tiansongxue Date: Tue, 10 Sep 2024 14:27:38 +0800 Subject: [PATCH 11/44] clean code --- 
.../built-in/audio/whisper_large_v3/mindie_whisper.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/MindIE/MindIE-Torch/built-in/audio/whisper_large_v3/mindie_whisper.py b/MindIE/MindIE-Torch/built-in/audio/whisper_large_v3/mindie_whisper.py index 565a4e6e95..13ffecf692 100644 --- a/MindIE/MindIE-Torch/built-in/audio/whisper_large_v3/mindie_whisper.py +++ b/MindIE/MindIE-Torch/built-in/audio/whisper_large_v3/mindie_whisper.py @@ -99,11 +99,8 @@ class MindieAttention(WhisperAttention): ) attn_output = attn_output.view(bsz, tgt_len, self.num_heads, self.head_dim) - attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim) - attn_output = self.out_proj(attn_output) - return attn_output, _, past_key_value @@ -120,7 +117,7 @@ class MindieWhisperDecoderLayer(WhisperDecoderLayer): is_decoder=True, is_causal=True, config=config - ) + ) self.encoder_attn = MindieAttention( self.embed_dim, @@ -159,7 +156,6 @@ class MindieWhisperDecoderLayer(WhisperDecoderLayer): # Cross-Attention Block cross_attn_present_key_value = None - cross_attn_weights = None if encoder_hidden_states is not None: residual = hidden_states hidden_states = self.encoder_attn_layer_norm(hidden_states) -- Gitee From 6f9e628df029f43cff0985e4dcf16d276a9e704d Mon Sep 17 00:00:00 2001 From: tiansongxue Date: Tue, 10 Sep 2024 15:03:24 +0800 Subject: [PATCH 12/44] fix query_states format as BSND --- .../audio/whisper_large_v3/mindie_whisper.py | 27 +++++-------------- 1 file changed, 7 insertions(+), 20 deletions(-) diff --git a/MindIE/MindIE-Torch/built-in/audio/whisper_large_v3/mindie_whisper.py b/MindIE/MindIE-Torch/built-in/audio/whisper_large_v3/mindie_whisper.py index 13ffecf692..03743383ef 100644 --- a/MindIE/MindIE-Torch/built-in/audio/whisper_large_v3/mindie_whisper.py +++ b/MindIE/MindIE-Torch/built-in/audio/whisper_large_v3/mindie_whisper.py @@ -42,7 +42,7 @@ class MindieAttention(WhisperAttention): is_causal, config ) - + def forward( self, hidden_states: torch.Tensor, @@ -82,9 +82,12 @@ class MindieAttention(WhisperAttention): past_key_value = (key_states.to(torch.float16), value_states.to(torch.float16)) # B S N D - query_states = self._shape(query_states, tgt_len, bsz) + query_states = query_states.view(bsz, tgt_len, self.num_heads, self.head_dim).contiguous() + # B S N D key_states = key_states.transpose(1, 2).contiguous() + + # B S N D value_states = value_states.transpose(1, 2).contiguous() # B S N D @@ -145,7 +148,7 @@ class MindieWhisperDecoderLayer(WhisperDecoderLayer): # Self Attention self_attn_past_key_value = past_key_value if past_key_value is not None else None - hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states, _, present_key_value = self.self_attn( hidden_states=hidden_states, past_key_value=self_attn_past_key_value, attention_mask=attention_mask, @@ -160,7 +163,7 @@ class MindieWhisperDecoderLayer(WhisperDecoderLayer): residual = hidden_states hidden_states = self.encoder_attn_layer_norm(hidden_states) - hidden_states, cross_attn_weights, cross_attn_present_key_value = self.encoder_attn( + hidden_states, _, cross_attn_present_key_value = self.encoder_attn( hidden_states=hidden_states, key_value_states=encoder_hidden_states, attention_mask=encoder_attention_mask, @@ -366,7 +369,6 @@ class MindieWhisperForConditionalGeneration(WhisperForConditionalGeneration, Gen traced_file = os.path.join(save_path, f"{self.file_prefix_names[1]}{args.bs}.pt") torch.jit.save(prefill_decoder_traced, traced_file) - print(f"Compile prefill_decoder success, saved 
in {save_file}.") print("Start compiling decoder.") @@ -679,21 +681,6 @@ class MindieWhisperForConditionalGeneration(WhisperForConditionalGeneration, Gen if file not in file_list: raise ValueError(f"Expected file name is {file}, but can't be found in path: {save_path}") - def _init_past_key_values(self, batch_size): - init_seq_len = 0 - layer_nums = 32 - head_num = 20 - head_size = 64 - - self.init_past_key_values = [torch.randn([batch_size, head_num, init_seq_len, head_size]).to(torch.float16).to("npu"), - torch.randn([batch_size, head_num, init_seq_len, head_size]).to(torch.float16).to("npu"), - torch.randn([batch_size, head_num, self.init_encoder_seq_len, head_size]).to(torch.float16).to("npu"), - torch.randn([batch_size, head_num, self.init_encoder_seq_len, head_size]).to(torch.float16).to("npu")] * layer_nums - print("init past key value sucess.") - - def _get_init_past_key_values(self): - return self.init_past_key_values - def prepare_inputs_for_generation( self, decoder_input_ids, -- Gitee From 46553922af8f7b212ae8dd1f49cdf089e57cce37 Mon Sep 17 00:00:00 2001 From: tiansongxue Date: Tue, 10 Sep 2024 15:51:46 +0800 Subject: [PATCH 13/44] specify all_past_key_value dtype as torch.float16 --- .../built-in/audio/whisper_large_v3/mindie_whisper.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/MindIE/MindIE-Torch/built-in/audio/whisper_large_v3/mindie_whisper.py b/MindIE/MindIE-Torch/built-in/audio/whisper_large_v3/mindie_whisper.py index 03743383ef..b663a48acc 100644 --- a/MindIE/MindIE-Torch/built-in/audio/whisper_large_v3/mindie_whisper.py +++ b/MindIE/MindIE-Torch/built-in/audio/whisper_large_v3/mindie_whisper.py @@ -61,7 +61,9 @@ class MindieAttention(WhisperAttention): and past_key_value[0].shape[2] == key_value_states.shape[1] ): # reuse k,v, cross_attentions + # B S N D key_states = past_key_value[0] + # B S N D value_states = past_key_value[1] elif is_cross_attention: # cross_attentions -- Gitee From 465f88417956ebc1aef2ec141eb4c9ab3c0f2d56 Mon Sep 17 00:00:00 2001 From: tiansongxue Date: Tue, 10 Sep 2024 15:53:37 +0800 Subject: [PATCH 14/44] specify all_past_key_value dtype as torch.float16 --- .../audio/whisper_large_v3/mindie_whisper.py | 43 +++++++++---------- 1 file changed, 20 insertions(+), 23 deletions(-) diff --git a/MindIE/MindIE-Torch/built-in/audio/whisper_large_v3/mindie_whisper.py b/MindIE/MindIE-Torch/built-in/audio/whisper_large_v3/mindie_whisper.py index b663a48acc..a6578e6079 100644 --- a/MindIE/MindIE-Torch/built-in/audio/whisper_large_v3/mindie_whisper.py +++ b/MindIE/MindIE-Torch/built-in/audio/whisper_large_v3/mindie_whisper.py @@ -42,6 +42,10 @@ class MindieAttention(WhisperAttention): is_causal, config ) + + # BSND + def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): + return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).contiguous() def forward( self, @@ -61,9 +65,7 @@ class MindieAttention(WhisperAttention): and past_key_value[0].shape[2] == key_value_states.shape[1] ): # reuse k,v, cross_attentions - # B S N D key_states = past_key_value[0] - # B S N D value_states = past_key_value[1] elif is_cross_attention: # cross_attentions @@ -73,8 +75,9 @@ class MindieAttention(WhisperAttention): # reuse k, v, self_attention key_states = self._shape(self.k_proj(hidden_states), -1, bsz) value_states = self._shape(self.v_proj(hidden_states), -1, bsz) - key_states = torch.cat([past_key_value[0], key_states], dim=2) - value_states = torch.cat([past_key_value[1], value_states], dim=2) + # concat at seq_len dim + key_states = 
torch.cat([past_key_value[0], key_states], dim=1) + value_states = torch.cat([past_key_value[1], value_states], dim=1) else: # self_attention key_states = self._shape(self.k_proj(hidden_states), -1, bsz) @@ -85,16 +88,10 @@ class MindieAttention(WhisperAttention): # B S N D query_states = query_states.view(bsz, tgt_len, self.num_heads, self.head_dim).contiguous() - - # B S N D - key_states = key_states.transpose(1, 2).contiguous() - - # B S N D - value_states = value_states.transpose(1, 2).contiguous() # B S N D attn_output = torch.ops.aie.flash_attention( - query = query_states, + query = query_states.to(torch.float16), key = key_states, value = value_states, num_head = self.num_heads, @@ -383,32 +380,32 @@ class MindieWhisperForConditionalGeneration(WhisperForConditionalGeneration, Gen encoder_outputs = torch.randn([args.bs, self.encoder_seq_len, hidden_size]) decoder_input_ids = torch.randint(1, 4, (args.bs, 1)) - all_past_key_value = [torch.randn([args.bs, head_num, seq_len, head_size]), - torch.randn([args.bs, head_num, seq_len, head_size]), - torch.randn([args.bs, head_num, self.encoder_seq_len, head_size]), - torch.randn([args.bs, head_num, self.encoder_seq_len, head_size])] * layer_nums + all_past_key_value = [torch.randn([args.bs, head_num, seq_len, head_size], dtype=torch.float16), + torch.randn([args.bs, head_num, seq_len, head_size], dtype=torch.float16), + torch.randn([args.bs, head_num, self.encoder_seq_len, head_size], dtype=torch.float16), + torch.randn([args.bs, head_num, self.encoder_seq_len, head_size], dtype=torch.float16)] * layer_nums traced_args = [decoder_input_ids, encoder_outputs] traced_args.extend(all_past_key_value) traced_decoder = torch.jit.trace(model.eval(), example_inputs=traced_args) - + # BSND key_value_infos = [mindietorch.Input( - min_shape=(args.bs, head_num, 0, head_size), - max_shape=(args.bs, head_num, max_len, head_size), + min_shape=(args.bs, 0, head_num, head_size), + max_shape=(args.bs, max_len, head_num, head_size), dtype=dtype.FLOAT16 ), mindietorch.Input( - min_shape=(args.bs, head_num, 0, head_size), - max_shape=(args.bs, head_num, max_len, head_size), + min_shape=(args.bs, 0, head_num, head_size), + max_shape=(args.bs, max_len, head_num, head_size), dtype=dtype.FLOAT16 ), mindietorch.Input( - min_shape=(args.bs, head_num, self.init_encoder_seq_len, head_size), + min_shape=(args.bs, head_num, self.encoder_seq_len, head_size), max_shape=(args.bs, head_num, self.encoder_seq_len, head_size), dtype=dtype.FLOAT16 ), mindietorch.Input( - min_shape=(args.bs, head_num, self.init_encoder_seq_len, head_size), - max_shape=(args.bs, head_num, self.encoder_seq_len, head_size), + min_shape=(args.bs, self.encoder_seq_len, head_num, head_size), + max_shape=(args.bs, self.encoder_seq_len, head_num, head_size), dtype=dtype.FLOAT16 ) ] * layer_nums -- Gitee From d94ce436d8eb8accd96d3df6351cca6285ded3ac Mon Sep 17 00:00:00 2001 From: tiansongxue Date: Tue, 10 Sep 2024 19:18:31 +0800 Subject: [PATCH 15/44] add mindieencoder --- .../audio/whisper_large_v3/mindie_whisper.py | 34 +++++++++++++++---- 1 file changed, 27 insertions(+), 7 deletions(-) diff --git a/MindIE/MindIE-Torch/built-in/audio/whisper_large_v3/mindie_whisper.py b/MindIE/MindIE-Torch/built-in/audio/whisper_large_v3/mindie_whisper.py index a6578e6079..2ad043bf24 100644 --- a/MindIE/MindIE-Torch/built-in/audio/whisper_large_v3/mindie_whisper.py +++ b/MindIE/MindIE-Torch/built-in/audio/whisper_large_v3/mindie_whisper.py @@ -14,7 +14,8 @@ import torch.utils.checkpoint from torch import nn from 
transformers import WhisperForConditionalGeneration from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask -from transformers.models.whisper.modeling_whisper import WhisperDecoder,WhisperModel,WhisperDecoderLayer,WhisperAttention +from transformers.models.whisper.modeling_whisper import WhisperDecoder, WhisperEncoder,WhisperEncoderLayer,\ +WhisperModel,WhisperDecoderLayer,WhisperAttention from transformers.generation.logits_process import LogitsProcessorList from transformers.generation.stopping_criteria import StoppingCriteriaList, validate_stopping_criteria from transformers.generation.utils import GenerationMixin @@ -248,6 +249,25 @@ class MindieWhisperDecoder(WhisperDecoder): hidden_states = self.layer_norm(hidden_states) return hidden_states, past_key_value_cache +class MindieWhisperEncoderLayer(WhisperEncoderLayer): + def __init__(self, config): + super().__init__(config) + self.self_attn = MindieAttention( + embed_dim=self.embed_dim, + num_heads=config.encoder_attention_heads, + dropout=config.attention_dropout, + config=config, + ) + + +class MindieWhisperEncoder(WhisperEncoder): + + def __init__(self, config): + super().__init__(config) + self.layers = nn.ModuleList([WhisperEncoderLayer(config) for _ in range(config.encoder_layers)]) + + + class MindieWhisperModel(WhisperModel): @@ -362,12 +382,12 @@ class MindieWhisperForConditionalGeneration(WhisperForConditionalGeneration, Gen inputs=input_info, precision_policy=mindietorch.PrecisionPolicy.FP16, soc_version=args.soc_version) - - save_file = os.path.join(save_path, f"{self.file_prefix_names[1]}{args.bs}.ts") - torch.jit.save(prefill_decoder_compiled, save_file) traced_file = os.path.join(save_path, f"{self.file_prefix_names[1]}{args.bs}.pt") torch.jit.save(prefill_decoder_traced, traced_file) + + save_file = os.path.join(save_path, f"{self.file_prefix_names[1]}{args.bs}.ts") + torch.jit.save(prefill_decoder_compiled, save_file) print(f"Compile prefill_decoder success, saved in {save_file}.") print("Start compiling decoder.") @@ -422,12 +442,12 @@ class MindieWhisperForConditionalGeneration(WhisperForConditionalGeneration, Gen soc_version=args.soc_version, default_buffer_size_vec=[buffer_size]) - save_file = os.path.join(save_path, f"{self.file_prefix_names[2]}{args.bs}.ts") - torch.jit.save(compiled_decoder, save_file) - traced_file = os.path.join(save_path, f"{self.file_prefix_names[2]}{args.bs}.pt") torch.jit.save(traced_decoder, traced_file) + save_file = os.path.join(save_path, f"{self.file_prefix_names[2]}{args.bs}.ts") + torch.jit.save(compiled_decoder, save_file) + print(f"Compile whisper_decoder success, saved in {save_file}.") self.has_compile = True -- Gitee From 35249202470e5347a803e777d76a5c349b25d683 Mon Sep 17 00:00:00 2001 From: tiansongxue Date: Tue, 10 Sep 2024 19:20:23 +0800 Subject: [PATCH 16/44] add mindieencoder --- .../built-in/audio/whisper_large_v3/mindie_whisper.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/MindIE/MindIE-Torch/built-in/audio/whisper_large_v3/mindie_whisper.py b/MindIE/MindIE-Torch/built-in/audio/whisper_large_v3/mindie_whisper.py index 2ad043bf24..4642e4afdf 100644 --- a/MindIE/MindIE-Torch/built-in/audio/whisper_large_v3/mindie_whisper.py +++ b/MindIE/MindIE-Torch/built-in/audio/whisper_large_v3/mindie_whisper.py @@ -265,8 +265,6 @@ class MindieWhisperEncoder(WhisperEncoder): def __init__(self, config): super().__init__(config) self.layers = nn.ModuleList([WhisperEncoderLayer(config) for _ in range(config.encoder_layers)]) - - class 
MindieWhisperModel(WhisperModel): @@ -274,6 +272,7 @@ class MindieWhisperModel(WhisperModel): def __init__(self, config): super().__init__(config) self.decoder = MindieWhisperDecoder(config) + self.encoder = MindieWhisperEncoder(config) def forward( self, -- Gitee From df77726f37ce6a85894e0884760bb6e30274d24b Mon Sep 17 00:00:00 2001 From: tiansongxue Date: Wed, 11 Sep 2024 09:10:28 +0800 Subject: [PATCH 17/44] remove PFA --- .../audio/whisper_large_v3/mindie_whisper.py | 48 +++++++++---------- 1 file changed, 24 insertions(+), 24 deletions(-) diff --git a/MindIE/MindIE-Torch/built-in/audio/whisper_large_v3/mindie_whisper.py b/MindIE/MindIE-Torch/built-in/audio/whisper_large_v3/mindie_whisper.py index 4642e4afdf..4ef037bf2b 100644 --- a/MindIE/MindIE-Torch/built-in/audio/whisper_large_v3/mindie_whisper.py +++ b/MindIE/MindIE-Torch/built-in/audio/whisper_large_v3/mindie_whisper.py @@ -113,22 +113,22 @@ class MindieWhisperDecoderLayer(WhisperDecoderLayer): super().__init__(config) self.embed_dim = config.d_model - self.self_attn = MindieAttention( - embed_dim=self.embed_dim, - num_heads=config.decoder_attention_heads, - dropout=config.attention_dropout, - is_decoder=True, - is_causal=True, - config=config - ) - - self.encoder_attn = MindieAttention( - self.embed_dim, - config.decoder_attention_heads, - dropout=config.attention_dropout, - is_decoder=True, - config=config, - ) + # self.self_attn = MindieAttention( + # embed_dim=self.embed_dim, + # num_heads=config.decoder_attention_heads, + # dropout=config.attention_dropout, + # is_decoder=True, + # is_causal=True, + # config=config + # ) + + # self.encoder_attn = MindieAttention( + # self.embed_dim, + # config.decoder_attention_heads, + # dropout=config.attention_dropout, + # is_decoder=True, + # config=config, + # ) def forward( self, @@ -399,22 +399,22 @@ class MindieWhisperForConditionalGeneration(WhisperForConditionalGeneration, Gen encoder_outputs = torch.randn([args.bs, self.encoder_seq_len, hidden_size]) decoder_input_ids = torch.randint(1, 4, (args.bs, 1)) - all_past_key_value = [torch.randn([args.bs, head_num, seq_len, head_size], dtype=torch.float16), - torch.randn([args.bs, head_num, seq_len, head_size], dtype=torch.float16), - torch.randn([args.bs, head_num, self.encoder_seq_len, head_size], dtype=torch.float16), - torch.randn([args.bs, head_num, self.encoder_seq_len, head_size], dtype=torch.float16)] * layer_nums + all_past_key_value = [torch.randn([args.bs, head_num, seq_len, head_size]), + torch.randn([args.bs, head_num, seq_len, head_size]), + torch.randn([args.bs, head_num, self.encoder_seq_len, head_size]), + torch.randn([args.bs, head_num, self.encoder_seq_len, head_size])] * layer_nums traced_args = [decoder_input_ids, encoder_outputs] traced_args.extend(all_past_key_value) traced_decoder = torch.jit.trace(model.eval(), example_inputs=traced_args) # BSND key_value_infos = [mindietorch.Input( - min_shape=(args.bs, 0, head_num, head_size), - max_shape=(args.bs, max_len, head_num, head_size), + min_shape=(args.bs, head_num, 1, head_size), + max_shape=(args.bs, head_num, max_len, head_size), dtype=dtype.FLOAT16 ), mindietorch.Input( - min_shape=(args.bs, 0, head_num, head_size), - max_shape=(args.bs, max_len, head_num, head_size), + min_shape=(args.bs, head_num, 1, head_size), + max_shape=(args.bs, head_num, max_len, head_size), dtype=dtype.FLOAT16 ), mindietorch.Input( -- Gitee From 677318c59987bda3983f29b585a5ed931737838a Mon Sep 17 00:00:00 2001 From: tiansongxue Date: Wed, 11 Sep 2024 10:25:43 +0800 Subject: [PATCH 18/44] 
remove double call from_pretrained --- .../built-in/audio/whisper_large_v3/mindie_whisper.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/MindIE/MindIE-Torch/built-in/audio/whisper_large_v3/mindie_whisper.py b/MindIE/MindIE-Torch/built-in/audio/whisper_large_v3/mindie_whisper.py index 4ef037bf2b..df13a82748 100644 --- a/MindIE/MindIE-Torch/built-in/audio/whisper_large_v3/mindie_whisper.py +++ b/MindIE/MindIE-Torch/built-in/audio/whisper_large_v3/mindie_whisper.py @@ -339,8 +339,7 @@ class MindieWhisperForConditionalGeneration(WhisperForConditionalGeneration, Gen else: print(f"Directory {save_path} already exists.") - model = MindieWhisperForConditionalGeneration.from_pretrained(model_path) - encoder = model.get_encoder() + encoder = self.get_encoder() class Encoder(torch.nn.Module): -- Gitee From 4e5594f943452201fee94ecb0914440b982c5a0a Mon Sep 17 00:00:00 2001 From: tiansongxue Date: Wed, 11 Sep 2024 11:05:09 +0800 Subject: [PATCH 19/44] rollback to decoder bnsd --- .../built-in/audio/whisper_large_v3/mindie_whisper.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/MindIE/MindIE-Torch/built-in/audio/whisper_large_v3/mindie_whisper.py b/MindIE/MindIE-Torch/built-in/audio/whisper_large_v3/mindie_whisper.py index df13a82748..44e8141768 100644 --- a/MindIE/MindIE-Torch/built-in/audio/whisper_large_v3/mindie_whisper.py +++ b/MindIE/MindIE-Torch/built-in/audio/whisper_large_v3/mindie_whisper.py @@ -373,7 +373,7 @@ class MindieWhisperForConditionalGeneration(WhisperForConditionalGeneration, Gen hidden_size = 1280 encoder_outputs = torch.randn([args.bs, encoder_seq_len, hidden_size]) decoder_input_ids = torch.randint(1, 4, (args.bs, 1)) - prefill_decoder_traced = torch.jit.trace(model.eval(), (decoder_input_ids, encoder_outputs)) + prefill_decoder_traced = torch.jit.trace(self.eval(), (decoder_input_ids, encoder_outputs)) input_info = [mindietorch.Input(shape=(args.bs, 1), dtype=dtype.INT64), mindietorch.Input(shape=(args.bs, encoder_seq_len, hidden_size))] prefill_decoder_compiled = mindietorch.compile(prefill_decoder_traced, @@ -404,7 +404,7 @@ class MindieWhisperForConditionalGeneration(WhisperForConditionalGeneration, Gen torch.randn([args.bs, head_num, self.encoder_seq_len, head_size])] * layer_nums traced_args = [decoder_input_ids, encoder_outputs] traced_args.extend(all_past_key_value) - traced_decoder = torch.jit.trace(model.eval(), example_inputs=traced_args) + traced_decoder = torch.jit.trace(self.eval(), example_inputs=traced_args) # BSND key_value_infos = [mindietorch.Input( min_shape=(args.bs, head_num, 1, head_size), @@ -422,8 +422,8 @@ class MindieWhisperForConditionalGeneration(WhisperForConditionalGeneration, Gen dtype=dtype.FLOAT16 ), mindietorch.Input( - min_shape=(args.bs, self.encoder_seq_len, head_num, head_size), - max_shape=(args.bs, self.encoder_seq_len, head_num, head_size), + min_shape=(args.bs, head_num, self.encoder_seq_len, head_size), + max_shape=(args.bs, head_num, self.encoder_seq_len, head_size), dtype=dtype.FLOAT16 ) ] * layer_nums -- Gitee From 91bff617d4973d6a59c8afc4ad19624986064847 Mon Sep 17 00:00:00 2001 From: tiansongxue Date: Wed, 11 Sep 2024 17:06:56 +0800 Subject: [PATCH 20/44] add if todo depend on PFA layout (bsnd or bnsd) --- .../built-in/audio/whisper_large_v3/mindie_whisper.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/MindIE/MindIE-Torch/built-in/audio/whisper_large_v3/mindie_whisper.py b/MindIE/MindIE-Torch/built-in/audio/whisper_large_v3/mindie_whisper.py index 
44e8141768..fcef9f0b5d 100644 --- a/MindIE/MindIE-Torch/built-in/audio/whisper_large_v3/mindie_whisper.py +++ b/MindIE/MindIE-Torch/built-in/audio/whisper_large_v3/mindie_whisper.py @@ -63,6 +63,7 @@ class MindieAttention(WhisperAttention): if ( is_cross_attention and past_key_value is not None + # if == bnsd shape[2] else bsnd shape[1] and past_key_value[0].shape[2] == key_value_states.shape[1] ): # reuse k,v, cross_attentions @@ -427,8 +428,13 @@ class MindieWhisperForConditionalGeneration(WhisperForConditionalGeneration, Gen dtype=dtype.FLOAT16 ) ] * layer_nums - input_info = [mindietorch.Input(min_shape=(args.bs, 1), max_shape=(args.bs, 1), dtype=dtype.INT64), - mindietorch.Input(min_shape=(args.bs, self.encoder_seq_len), max_shape=(args.bs, self.encoder_seq_len))] + input_info = [mindietorch.Input(min_shape=(args.bs, 1), + max_shape=(args.bs, 1), dtype=dtype.INT64 + ), + mindietorch.Input(min_shape=(args.bs, self.encoder_seq_len, hidden_size), + max_shape=(args.bs, self.encoder_seq_len, hidden_size) + ) + ] input_info.extend(key_value_infos) float_size = 4 # pre set malloc each output buffer size/MB -- Gitee From b3a2c9dd11f3ba4c117132550bc6026d1ba89a12 Mon Sep 17 00:00:00 2001 From: tiansongxue Date: Wed, 11 Sep 2024 17:16:01 +0800 Subject: [PATCH 21/44] fix PFA layout as BSND --- .../built-in/audio/whisper_large_v3/mindie_whisper.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/MindIE/MindIE-Torch/built-in/audio/whisper_large_v3/mindie_whisper.py b/MindIE/MindIE-Torch/built-in/audio/whisper_large_v3/mindie_whisper.py index fcef9f0b5d..ad080f97e1 100644 --- a/MindIE/MindIE-Torch/built-in/audio/whisper_large_v3/mindie_whisper.py +++ b/MindIE/MindIE-Torch/built-in/audio/whisper_large_v3/mindie_whisper.py @@ -64,7 +64,7 @@ class MindieAttention(WhisperAttention): is_cross_attention and past_key_value is not None # if == bnsd shape[2] else bsnd shape[1] - and past_key_value[0].shape[2] == key_value_states.shape[1] + and past_key_value[0].shape[1] == key_value_states.shape[1] ): # reuse k,v, cross_attentions key_states = past_key_value[0] @@ -93,7 +93,7 @@ class MindieAttention(WhisperAttention): # B S N D attn_output = torch.ops.aie.flash_attention( - query = query_states.to(torch.float16), + query = query_states, key = key_states, value = value_states, num_head = self.num_heads, @@ -406,7 +406,7 @@ class MindieWhisperForConditionalGeneration(WhisperForConditionalGeneration, Gen traced_args = [decoder_input_ids, encoder_outputs] traced_args.extend(all_past_key_value) traced_decoder = torch.jit.trace(self.eval(), example_inputs=traced_args) - # BSND + # if using PFA, the BSND layout needed key_value_infos = [mindietorch.Input( min_shape=(args.bs, head_num, 1, head_size), max_shape=(args.bs, head_num, max_len, head_size), -- Gitee From 12798e0665a8a841773f07d202a35f0f860cd680 Mon Sep 17 00:00:00 2001 From: tiansongxue Date: Thu, 12 Sep 2024 10:52:03 +0800 Subject: [PATCH 22/44] remove PFA at encoder --- .../built-in/audio/whisper_large_v3/mindie_whisper.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/MindIE/MindIE-Torch/built-in/audio/whisper_large_v3/mindie_whisper.py b/MindIE/MindIE-Torch/built-in/audio/whisper_large_v3/mindie_whisper.py index ad080f97e1..066e669cd4 100644 --- a/MindIE/MindIE-Torch/built-in/audio/whisper_large_v3/mindie_whisper.py +++ b/MindIE/MindIE-Torch/built-in/audio/whisper_large_v3/mindie_whisper.py @@ -273,7 +273,7 @@ class MindieWhisperModel(WhisperModel): def __init__(self, config): super().__init__(config) 
self.decoder = MindieWhisperDecoder(config) - self.encoder = MindieWhisperEncoder(config) + # self.encoder = MindieWhisperEncoder(config) def forward( self, -- Gitee From e58412b77ec06e53c12f0f02cf2eca7166943ddb Mon Sep 17 00:00:00 2001 From: tiansongxue Date: Thu, 12 Sep 2024 14:19:42 +0800 Subject: [PATCH 23/44] fix label[1:-1] --- .../built-in/audio/whisper_large_v3/mindie_whisper.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/MindIE/MindIE-Torch/built-in/audio/whisper_large_v3/mindie_whisper.py b/MindIE/MindIE-Torch/built-in/audio/whisper_large_v3/mindie_whisper.py index 066e669cd4..497fbce4a4 100644 --- a/MindIE/MindIE-Torch/built-in/audio/whisper_large_v3/mindie_whisper.py +++ b/MindIE/MindIE-Torch/built-in/audio/whisper_large_v3/mindie_whisper.py @@ -750,7 +750,7 @@ class AishellDataset(Dataset): with open(self.label_file, 'r', encoding='utf-8') as file: for line in file: name = line.strip().split(" ")[0] - trans = "".join(x for x in line.split(" ")[1:-1]) + trans = "".join(x for x in line.split(" ")[1:]) self.data_label_map[name] = trans def audio_2_trans(self): -- Gitee From baa6987961ef38e501b6a75bc56897fd2e0c3d83 Mon Sep 17 00:00:00 2001 From: tiansongxue Date: Thu, 12 Sep 2024 15:23:31 +0800 Subject: [PATCH 24/44] add strip() --- .../built-in/audio/whisper_large_v3/mindie_whisper.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/MindIE/MindIE-Torch/built-in/audio/whisper_large_v3/mindie_whisper.py b/MindIE/MindIE-Torch/built-in/audio/whisper_large_v3/mindie_whisper.py index 497fbce4a4..90f115d997 100644 --- a/MindIE/MindIE-Torch/built-in/audio/whisper_large_v3/mindie_whisper.py +++ b/MindIE/MindIE-Torch/built-in/audio/whisper_large_v3/mindie_whisper.py @@ -750,7 +750,7 @@ class AishellDataset(Dataset): with open(self.label_file, 'r', encoding='utf-8') as file: for line in file: name = line.strip().split(" ")[0] - trans = "".join(x for x in line.split(" ")[1:]) + trans = "".join(x for x in line.strip().split(" ")[1:]) self.data_label_map[name] = trans def audio_2_trans(self): -- Gitee From b677cb60c71e3d056a078def067d5815657092ef Mon Sep 17 00:00:00 2001 From: tiansongxue Date: Mon, 23 Sep 2024 17:35:23 +0800 Subject: [PATCH 25/44] fix DTS --- .../built-in/audio/whisper_large_v3/mindie_whisper.py | 2 +- .../MindIE-Torch/built-in/audio/whisper_large_v3/readme.md | 5 ++--- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/MindIE/MindIE-Torch/built-in/audio/whisper_large_v3/mindie_whisper.py b/MindIE/MindIE-Torch/built-in/audio/whisper_large_v3/mindie_whisper.py index e0aeb3634b..0fc86f5b63 100644 --- a/MindIE/MindIE-Torch/built-in/audio/whisper_large_v3/mindie_whisper.py +++ b/MindIE/MindIE-Torch/built-in/audio/whisper_large_v3/mindie_whisper.py @@ -628,7 +628,7 @@ class AishellDataset(Dataset): self.label_file = trans_file self.data_label_map = {} self.sampling_rate = sampling_rate - self.DURATION = 49762.8 + self.DURATION = 36108.92 self.audio_2_trans() diff --git a/MindIE/MindIE-Torch/built-in/audio/whisper_large_v3/readme.md b/MindIE/MindIE-Torch/built-in/audio/whisper_large_v3/readme.md index 2e6693e475..4c1175438f 100644 --- a/MindIE/MindIE-Torch/built-in/audio/whisper_large_v3/readme.md +++ b/MindIE/MindIE-Torch/built-in/audio/whisper_large_v3/readme.md @@ -58,11 +58,10 @@ 1. 设置mindie内存池上限为16,执行如下命令设置环境变量 ``` export TORCH_AIE_NPU_CACHE_MAX_SIZE=16 - export ASCENDIE_FASTER_MODE=1 ``` 2. 
模型编译和推理 - 执行下述命令进行模型编译和推理,sample.py脚本配套使用的数据集链接https://openslr.org/33/,下载得到data_aishell.tgz。如需使用其他的数据集,请修改对应的数据集解析方式。 + 执行下述命令进行模型编译和推理,sample.py脚本配套使用的数据集链接https://openslr.org/33,下载得到data_aishell.tgz。如需使用其他的数据集,请修改对应的数据集解析方式。 ```bash tar -xzf data_aishell.tgz cd ./data_aishell/wav @@ -75,7 +74,7 @@ | |--aishell_transcript_v0.8.txt # 标签数据 |---wav | |--S0*.tar.gz - | |--test # 测试语料,约10小时 + | |--test | |--dev | |--train ``` -- Gitee From 519099f42df76b3b8e88a34850ae4a60fb1ea0f8 Mon Sep 17 00:00:00 2001 From: tiansongxue Date: Mon, 23 Sep 2024 19:09:46 +0800 Subject: [PATCH 26/44] fix readme add max_cache 32G fix 910B error as decoder add PFA fix input_ids shape encoder add aoe job 1 encoder add PFA fix DTS --- .../audio/whisper_large_v3/mindie_whisper.py | 69 +++++-------------- .../built-in/audio/whisper_large_v3/readme.md | 31 +++++---- .../built-in/audio/whisper_large_v3/sample.py | 4 ++ 3 files changed, 37 insertions(+), 67 deletions(-) diff --git a/MindIE/MindIE-Torch/built-in/audio/whisper_large_v3/mindie_whisper.py b/MindIE/MindIE-Torch/built-in/audio/whisper_large_v3/mindie_whisper.py index 8d3eb65e00..4694e44191 100644 --- a/MindIE/MindIE-Torch/built-in/audio/whisper_large_v3/mindie_whisper.py +++ b/MindIE/MindIE-Torch/built-in/audio/whisper_large_v3/mindie_whisper.py @@ -56,37 +56,12 @@ class MindieAttention(WhisperAttention): **kwargs ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - is_cross_attention = key_value_states is not None bsz, tgt_len, _ = hidden_states.size() query_states = self.q_proj(hidden_states) - - if ( - is_cross_attention - and past_key_value is not None - # if == bnsd shape[2] else bsnd shape[1] - and past_key_value[0].shape[1] == key_value_states.shape[1] - ): - # reuse k,v, cross_attentions - key_states = past_key_value[0] - value_states = past_key_value[1] - elif is_cross_attention: - # cross_attentions - key_states = self._shape(self.k_proj(key_value_states), -1, bsz) - value_states = self._shape(self.v_proj(key_value_states), -1, bsz) - elif past_key_value is not None: - # reuse k, v, self_attention - key_states = self._shape(self.k_proj(hidden_states), -1, bsz) - value_states = self._shape(self.v_proj(hidden_states), -1, bsz) - # concat at seq_len dim - key_states = torch.cat([past_key_value[0], key_states], dim=1) - value_states = torch.cat([past_key_value[1], value_states], dim=1) - else: - # self_attention - key_states = self._shape(self.k_proj(hidden_states), -1, bsz) - value_states = self._shape(self.v_proj(hidden_states), -1, bsz) - if self.is_decoder: - past_key_value = (key_states.to(torch.float16), value_states.to(torch.float16)) + # self_attention + key_states = self._shape(self.k_proj(hidden_states), -1, bsz) + value_states = self._shape(self.v_proj(hidden_states), -1, bsz) # B S N D query_states = query_states.view(bsz, tgt_len, self.num_heads, self.head_dim).contiguous() @@ -105,7 +80,7 @@ class MindieAttention(WhisperAttention): attn_output = attn_output.view(bsz, tgt_len, self.num_heads, self.head_dim) attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim) attn_output = self.out_proj(attn_output) - return attn_output, _, past_key_value + return attn_output, _, _ class MindieWhisperDecoderLayer(WhisperDecoderLayer): @@ -114,23 +89,6 @@ class MindieWhisperDecoderLayer(WhisperDecoderLayer): super().__init__(config) self.embed_dim = config.d_model - # self.self_attn = MindieAttention( - # embed_dim=self.embed_dim, - # num_heads=config.decoder_attention_heads, - # dropout=config.attention_dropout, - # 
is_decoder=True, - # is_causal=True, - # config=config - # ) - - # self.encoder_attn = MindieAttention( - # self.embed_dim, - # config.decoder_attention_heads, - # dropout=config.attention_dropout, - # is_decoder=True, - # config=config, - # ) - def forward( self, hidden_states, @@ -265,7 +223,7 @@ class MindieWhisperEncoder(WhisperEncoder): def __init__(self, config): super().__init__(config) - self.layers = nn.ModuleList([WhisperEncoderLayer(config) for _ in range(config.encoder_layers)]) + self.layers = nn.ModuleList([MindieWhisperEncoderLayer(config) for _ in range(config.encoder_layers)]) class MindieWhisperModel(WhisperModel): @@ -273,7 +231,7 @@ class MindieWhisperModel(WhisperModel): def __init__(self, config): super().__init__(config) self.decoder = MindieWhisperDecoder(config) - # self.encoder = MindieWhisperEncoder(config) + self.encoder = MindieWhisperEncoder(config) def forward( self, @@ -356,12 +314,19 @@ class MindieWhisperForConditionalGeneration(WhisperForConditionalGeneration, Gen input_features = torch.randn([args.bs, mel_feature_size, max_frames]) encoder_traced = torch.jit.trace(Encoder(encoder), (input_features)) input_info = [mindietorch.Input(shape=(args.bs, mel_feature_size, max_frames))] - print("Start compiling encoder.") + print("Start AOE optimization_level 1.") + compiled = mindietorch.compile(encoder_traced, + inputs=input_info, + precision_policy=mindietorch.PrecisionPolicy.FP16, + soc_version=args.soc_version, + optimization_level=1 + ) + print("AOE optimize finished and start compile encoder.") compiled = mindietorch.compile(encoder_traced, inputs=input_info, precision_policy=mindietorch.PrecisionPolicy.FP16, - soc_version=args.soc_version - ) + soc_version=args.soc_version, + ) save_file = os.path.join(save_path, f"{self.file_prefix_names[0]}{args.bs}.ts") torch.jit.save(compiled, save_file) @@ -712,7 +677,7 @@ class MindieWhisperForConditionalGeneration(WhisperForConditionalGeneration, Gen **kwargs ): if past_key_values is not None: - past_key_values_shape = past_key_values[0].shape + past_key_values_shape = past_key_values[0].shape # B N S D past_length = past_key_values_shape[-2] decoder_input_ids_shape = decoder_input_ids.shape # Some generation methods already pass only the last input ID diff --git a/MindIE/MindIE-Torch/built-in/audio/whisper_large_v3/readme.md b/MindIE/MindIE-Torch/built-in/audio/whisper_large_v3/readme.md index 76c4f8d80f..3b3d26ed4a 100644 --- a/MindIE/MindIE-Torch/built-in/audio/whisper_large_v3/readme.md +++ b/MindIE/MindIE-Torch/built-in/audio/whisper_large_v3/readme.md @@ -17,13 +17,13 @@ **表 1** 版本配套表 - | 配套 | 版本 | 环境准备指导 | - | ------ | ------- | ------------ | - | Python | 3.10.13 | - | - | torch | 2.1.0+cpu | - | - | torch_audio | 2.1.0+cpu | - | - | CANN | 8.0.RC2 | - | - | MindIE | 1.0.RC2.B091 | - | + | 配套 | 版本 | 环境准备指导 | + |-----------| ------- | ------------ | + | Python | 3.10.13 | - | + | torch | 2.1.0+cpu | - | + | torch_audio | 2.1.0+cpu | - | + | CANN | 8.0.RC3 | - | + | MindIE | 1.0.RC3 | - | # 快速上手 ## 获取源码 @@ -42,7 +42,7 @@ ```bash https://huggingface.co/openai/whisper-large-v3/tree/main ``` - 将权重文件存放至当前目录下的mode_path文件夹,请先创建改文件夹。 + 将权重文件存放至当前目录下的model_path文件夹,请先创建改文件夹。 5. 安装依赖 @@ -54,23 +54,24 @@ pip3 install numpy==1.26.0 ``` -6. 补丁(临时,后续接入PFA之后会修改) +6. 补丁 ``` cd 到transformers的安装路径,找到modeling_whisper.py文件 将408行的 past_key_value = (key_states, value_states) 修改为: past_key_value = (key_states.to(torch.float16), value_states.to(torch.float16)) ``` - + ## 模型推理 -1. 设置mindie内存池上限为16,执行如下命令设置环境变量 +1. 
设置mindie内存池上限为32,执行如下命令设置环境变量。内存池设置过小,内存重复申请和释放会影响性能。 ``` - export TORCH_AIE_NPU_CACHE_MAX_SIZE=16 + export TORCH_AIE_NPU_CACHE_MAX_SIZE=32 ``` 2. 模型编译和推理 - 执行下述命令进行模型编译和推理,sample.py脚本配套使用的数据集链接https://openslr.org/33,下载得到data_aishell.tgz。如需使用其他的数据集,请修改对应的数据集解析方式。 + 执行下述命令进行模型编译和推理,sample.py脚本配套使用的数据集链接https://openslr.org/33 + 下载得到data_aishell.tgz。如需使用其他的数据集,请修改对应的数据集解析方式。 ```bash tar -xzf data_aishell.tgz cd ./data_aishell/wav @@ -118,7 +119,7 @@ 4. 仅支持使用MindiePipeline拉起推理流程 评估指标: - 1. 性能指标:bs=8(当前最优,带持续优化)时,推理aishell数据集需要5200+秒,转录比(语料总时长/推理总时长)6.8倍。 + 1. 性能指标:bs=8时,转录比(语料总时长/推理总时长)6.3倍。 2. 精度指标:采用错词率作为评估指标,Aishell数据集上的错词率为10.06% @@ -129,7 +130,7 @@ ``` pip3 install torch-npu==2.1.0 ``` -2. 执行test_tools脚本。(改脚本不作为交付件,开发和自测对齐精度使用) +2. 执行test_tools脚本。 ```bash python3 test_tools.py \ -model_path ./model_path \ diff --git a/MindIE/MindIE-Torch/built-in/audio/whisper_large_v3/sample.py b/MindIE/MindIE-Torch/built-in/audio/whisper_large_v3/sample.py index 06ebb8af24..b3a3099085 100644 --- a/MindIE/MindIE-Torch/built-in/audio/whisper_large_v3/sample.py +++ b/MindIE/MindIE-Torch/built-in/audio/whisper_large_v3/sample.py @@ -44,6 +44,10 @@ if __name__ == "__main__": wer += edit_distance(predict, label) / len(label) step += 1 print(f"predict {transcription}, labels {labels}") + + # reset dataset and dataloader + aishell_dataset = AishellDataset(args.audio_dir, args.trans_file, sampling_rate) + data_loader = DataLoader(aishell_dataset, batch_size=args.bs) t1 = time.time() for inp, labels in data_loader: input_features = processor(inp, sampling_rate=sampling_rate, return_tensors="pt").input_features -- Gitee From 577bd3ae13a266994d37440ed0abfeac88f9e6e3 Mon Sep 17 00:00:00 2001 From: tiansongxue Date: Tue, 29 Oct 2024 09:30:54 +0800 Subject: [PATCH 27/44] add whisperx/modeling_whisper --- .../audio/whisperX/compile_whisper.py | 24 + .../audio/whisperX/modeling_whisper.py | 925 ++++++++++++++++++ 2 files changed, 949 insertions(+) create mode 100644 MindIE/MindIE-Torch/built-in/audio/whisperX/compile_whisper.py create mode 100644 MindIE/MindIE-Torch/built-in/audio/whisperX/modeling_whisper.py diff --git a/MindIE/MindIE-Torch/built-in/audio/whisperX/compile_whisper.py b/MindIE/MindIE-Torch/built-in/audio/whisperX/compile_whisper.py new file mode 100644 index 0000000000..88cbf73589 --- /dev/null +++ b/MindIE/MindIE-Torch/built-in/audio/whisperX/compile_whisper.py @@ -0,0 +1,24 @@ +import argparse +import time +from transformers import WhisperProcessor +from nltk.metrics.distance import edit_distance + +import mindietorch +from modeling_whisper import MindieWhisperForConditionalGeneration + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument('-model_path', type=str, required=True, help="please provide model path.") + parser.add_argument('-audio_dir', type=str, required=True, help="please provide speech samples dir.") + parser.add_argument('-trans_file', type=str, required=True, help="please provide label file.") + parser.add_argument('-bs', type=int, default=8, help="please provide batch_size, default:8.") + parser.add_argument('-soc_version', type=str, choices=["Ascend310P3", "Ascend910B4"], default="Ascend310P3", + help="please provide soc_version, default:Ascend310P3.") + parser.add_argument('-save_path', type=str, default="compiled_models", help="compiled models save dir.") + parser.add_argument('-device_id', type=int, default=0) + + args = parser.parse_args() + device = f"npu:{args.device_id}" + mindie_whisper = 
MindieWhisperForConditionalGeneration.from_pretrained(args.model_path,
+                                                                          enable_incre_flash_attention=False)
+    mindie_whisper.compile(args)
\ No newline at end of file
diff --git a/MindIE/MindIE-Torch/built-in/audio/whisperX/modeling_whisper.py b/MindIE/MindIE-Torch/built-in/audio/whisperX/modeling_whisper.py
new file mode 100644
index 0000000000..249fc1f9f5
--- /dev/null
+++ b/MindIE/MindIE-Torch/built-in/audio/whisperX/modeling_whisper.py
@@ -0,0 +1,925 @@
+import argparse
+import os
+import copy
+import math
+import warnings
+import librosa
+from typing import Optional, Tuple, Union, List, Dict
+from collections import OrderedDict
+
+import torch
+from torch.utils.data import Dataset
+import torch.nn.functional as F
+import torch.utils.checkpoint
+from torch import nn
+from transformers import WhisperForConditionalGeneration
+from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
+from transformers.models.whisper.modeling_whisper import WhisperDecoder, WhisperEncoder, WhisperEncoderLayer, \
+    WhisperModel, WhisperDecoderLayer, WhisperAttention
+from transformers.generation.logits_process import LogitsProcessorList
+from transformers.generation.stopping_criteria import StoppingCriteriaList, validate_stopping_criteria
+from transformers.generation.utils import GenerationMixin
+import mindietorch
+from mindietorch._enums import dtype
+
+
+class MindiePFA(WhisperAttention):
+
+    def __init__(
+        self,
+        embed_dim: int,
+        num_heads: int,
+        dropout: float = 0.0,
+        is_decoder: bool = False,
+        bias: bool = True,
+        is_causal: bool = False,
+        config=None
+    ):
+        super().__init__(
+            embed_dim,
+            num_heads,
+            dropout,
+            is_decoder,
+            bias,
+            is_causal,
+            config)
+
+    # BSND
+    def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
+        return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).contiguous()
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        key_value_states: Optional[torch.Tensor] = None,
+        past_key_value: Optional[Tuple[torch.Tensor]] = None,
+        **kwargs
+    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+        # cross_attn
+        is_cross_attn = key_value_states is not None
+        bsz, tgt_len, _ = hidden_states.size()
+        query_states = self.q_proj(hidden_states)
+        if (
+            is_cross_attn
+            and past_key_value is not None
+            # past_key_value layout is BSND
+            and past_key_value[0].shape[1] == key_value_states.shape[1]
+        ):
+            # reuse k, v
+            key_states = past_key_value[0]
+            value_states = past_key_value[1]
+        elif is_cross_attn:
+            key_states = self._shape(self.k_proj(key_value_states), -1, bsz)
+            value_states = self._shape(self.v_proj(key_value_states), -1, bsz)
+        # self attn
+        elif past_key_value is not None:
+            # reuse k v
+            key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
+            value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
+            # concat at seq_len dim
+            key_states = torch.cat([past_key_value[0], key_states], dim=1)
+            value_states = torch.cat([past_key_value[1], value_states], dim=1)
+        else:
+            key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
+            value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
+
+        if self.is_decoder:
+            past_key_value = (key_states.to(torch.float16), value_states.to(torch.float16))
+
+        # B S N D
+        query_states = query_states.view(bsz, tgt_len, self.num_heads, self.head_dim).contiguous()
+
+        # B S N D
+        attn_output = torch.ops.aie.flash_attention(
+            query=query_states,
+            key=key_states,
+            value=value_states,
+            num_head=self.num_heads,
+            scale=self.scaling,
+            layout="BSND",
+            type="PFA"
+        )
+        attn_output = attn_output.view(bsz, tgt_len, self.num_heads, self.head_dim)
+        attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)
+        attn_output = self.out_proj(attn_output)
+        return attn_output, None, past_key_value
+
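+# Cache layout note: every K/V tensor handled below is BSND, i.e.
+# (batch, seq_len, num_heads, head_dim). For whisper-large-v3 (20 heads, head_dim 64)
+# at bs=8 this gives, for example, a cross-attention cache of (8, 1500, 20, 64) and a
+# pre-allocated self-attention cache of (8, 448, 20, 64).
+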
+class MindieIFA(WhisperAttention):
+    def __init__(
+        self,
+        embed_dim: int,
+        num_heads: int,
+        dropout: float = 0.0,
+        is_decoder: bool = False,
+        bias: bool = True,
+        is_causal: bool = True,
+        config=None
+    ):
+        super().__init__(
+            embed_dim,
+            num_heads,
+            dropout,
+            is_decoder,
+            bias,
+            is_causal,
+            config
+        )
+
+    # BSND
+    def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
+        return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).contiguous()
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        past_key_value: torch.Tensor,
+        actual_seq_len: torch.Tensor,
+        **kwargs
+    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+        # self attn, one token per decode step
+        assert past_key_value is not None
+        bsz, tgt_len, _ = hidden_states.size()
+        assert tgt_len == 1
+        query_states = self.q_proj(hidden_states)
+
+        key_states = self._shape(self.k_proj(hidden_states), -1, bsz).to(torch.float16)
+        value_states = self._shape(self.v_proj(hidden_states), -1, bsz).to(torch.float16)
+
+        past_key_cache, past_value_cache = past_key_value[0], past_key_value[1]
+
+        # write this step's K/V into the cache row of the current position
+        indices = actual_seq_len - 1
+        past_key_cache = torch.ops.aie.scatter_update(past_key_cache, indices, key_states, axis=1)
+        past_value_cache = torch.ops.aie.scatter_update(past_value_cache, indices, value_states, axis=1)
+
+        past_key_value = (past_key_cache, past_value_cache)
+        # B S N D
+        query_states = query_states.view(bsz, tgt_len, self.num_heads, self.head_dim).contiguous()
+
+        # B S N D
+        attn_output = torch.ops.aie.incre_flash_attention(
+            query=query_states,
+            key=past_key_cache,
+            value=past_value_cache,
+            actual_seq_len=actual_seq_len,
+            num_heads=self.num_heads,
+            scale=self.scaling,
+            layout="BSND")
+
+        attn_output = attn_output.view(bsz, tgt_len, self.num_heads, self.head_dim)
+        attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)
+        attn_output = self.out_proj(attn_output)
+        return attn_output, None, past_key_value
+
+
+class MindieAttention(WhisperAttention):
+    def __init__(
+        self,
+        embed_dim: int,
+        num_heads: int,
+        dropout: float = 0.0,
+        is_decoder: bool = False,
+        bias: bool = True,
+        is_causal: bool = False,
+        config=None
+    ):
+        super().__init__(
+            embed_dim,
+            num_heads,
+            dropout,
+            is_decoder,
+            bias,
+            is_causal,
+            config
+        )
+
+    # BSND
+    def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
+        return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).contiguous()
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        key_value_states: Optional[torch.Tensor] = None,
+        past_key_value: Optional[Tuple[torch.Tensor]] = None,
+        **kwargs
+    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+        bsz, tgt_len, _ = hidden_states.size()
+        query_states = self.q_proj(hidden_states)
+
+        # self_attention
+        key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
+        value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
+
+        # B S N D
+        query_states = query_states.view(bsz, tgt_len, self.num_heads, self.head_dim).contiguous()
+
+        # B S N D
+        attn_output = torch.ops.aie.flash_attention(
+            query=query_states,
+            key=key_states,
+            value=value_states,
+            num_head=self.num_heads,
+            scale=self.scaling,
+            layout="BSND",
+            type="PFA"
+        )
+
+        attn_output = attn_output.view(bsz, tgt_len, self.num_heads, self.head_dim)
+        attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)
+        attn_output = self.out_proj(attn_output)
+        return attn_output, None, None
+
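+# Rough sketch of the in-place cache update performed by MindieIFA (assumed kernel
+# semantics): with indices = actual_seq_len - 1,
+# scatter_update(cache, indices, update, axis=1) behaves per batch element like
+#     cache[b, indices[b], :, :] = update[b, 0, :, :]
+# so each decode step writes its single-token K/V into the (bsz, max_len, N, D) cache.
+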
+
+class MindieWhisperDecoderLayer(WhisperDecoderLayer):
+
+    def __init__(self, config):
+        super().__init__(config)
+        self.embed_dim = config.d_model
+        if config.is_use_ifa:
+            self.self_attn = MindieIFA(
+                embed_dim=self.embed_dim,
+                num_heads=config.decoder_attention_heads,
+                config=config,
+                is_decoder=True
+            )
+        else:
+            self.self_attn = MindiePFA(
+                embed_dim=self.embed_dim,
+                num_heads=config.decoder_attention_heads,
+                config=config,
+                is_decoder=True
+            )
+        self.encoder_attn = MindiePFA(
+            embed_dim=self.embed_dim,
+            num_heads=config.decoder_attention_heads,
+            is_decoder=True,
+            config=config
+        )
+
+    def forward(
+        self,
+        hidden_states,
+        attention_mask,
+        encoder_hidden_states,
+        layer_head_mask,
+        cross_attn_layer_head_mask,
+        past_key_value,
+        cross_attn_past_key_value,
+        output_attentions,
+        use_cache,
+        actual_seq_len=None,
+        encoder_attention_mask=None
+    ) -> torch.Tensor:
+        residual = hidden_states
+        hidden_states = self.self_attn_layer_norm(hidden_states)
+
+        # Self Attention
+        self_attn_past_key_value = past_key_value if past_key_value is not None else None
+        hidden_states, _, present_key_value = self.self_attn(
+            hidden_states=hidden_states,
+            past_key_value=self_attn_past_key_value,
+            attention_mask=attention_mask,
+            layer_head_mask=layer_head_mask,
+            output_attentions=output_attentions,
+            actual_seq_len=actual_seq_len
+        )
+        hidden_states = residual + hidden_states
+
+        # Cross-Attention Block
+        cross_attn_present_key_value = None
+        if encoder_hidden_states is not None:
+            residual = hidden_states
+            hidden_states = self.encoder_attn_layer_norm(hidden_states)
+
+            hidden_states, _, cross_attn_present_key_value = self.encoder_attn(
+                hidden_states=hidden_states,
+                key_value_states=encoder_hidden_states,
+                attention_mask=encoder_attention_mask,
+                layer_head_mask=cross_attn_layer_head_mask,
+                past_key_value=cross_attn_past_key_value,
+                output_attentions=output_attentions,
+            )
+            hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
+            hidden_states = residual + hidden_states
+
+            # add cross-attn to positions 3,4 of present_key_value tuple
+            present_key_value = present_key_value + cross_attn_present_key_value
+
+        # Fully Connected
+        residual = hidden_states
+        hidden_states = self.final_layer_norm(hidden_states)
+        hidden_states = hidden_states.reshape(-1, self.embed_dim)
+        hidden_states = self.activation_fn(self.fc1(hidden_states))
+        hidden_states = self.fc2(hidden_states)
+        hidden_states = hidden_states.reshape(-1, 1, self.embed_dim)
+        hidden_states = residual + hidden_states
+
+        outputs = (hidden_states,)
+
+        if use_cache:
+            outputs += (present_key_value,)
+
+        return outputs
+
+
+class MindieWhisperDecoder(WhisperDecoder):
+
+    def __init__(self, config):
+        super().__init__(config)
+        self.layers = nn.ModuleList([MindieWhisperDecoderLayer(config) for _ in range(config.decoder_layers)])
+        self.config = config
+
+    def forward(
+        self,
+        input_ids,
+        encoder_hidden_states,
+        past_key_values,
+        actual_seq_len,
+        use_cache=True,
+        attention_mask=None):
+        if input_ids is None:
+            raise ValueError("You have to specify decoder_input_ids")
+        inputs_embeds = self.embed_tokens(input_ids)
+
+        input_shape = input_ids.size()
+        input_ids = input_ids.view(-1, input_shape[-1])
+        # B S N D
+        past_key_values_length = past_key_values[0].shape[1] if past_key_values is not None else 0
+
+        attention_mask = _prepare_4d_causal_attention_mask(
+            attention_mask, 
input_shape, inputs_embeds, past_key_values_length + ) + inputs_embeds = self.embed_tokens(input_ids) + # embed positions + positions = self.embed_positions(input_ids, past_key_values_length=past_key_values_length) + hidden_states = inputs_embeds + positions + + past_key_value_cache = [] + for idx, decoder_layer in enumerate(self.layers): + past_key_value = (past_key_values[4 * idx], past_key_values[4 * idx + 1]) \ + if past_key_values is not None else None + cross_past_key_value = (past_key_values[4 * idx + 2], past_key_values[4 * idx + 3]) \ + if past_key_values is not None else None + + layer_outputs = decoder_layer( + hidden_states, + actual_seq_len=actual_seq_len, + attention_mask=attention_mask, + encoder_hidden_states=encoder_hidden_states, + layer_head_mask=None, + cross_attn_layer_head_mask=None, + past_key_value=past_key_value, + cross_attn_past_key_value=cross_past_key_value, + output_attentions=None, + use_cache=use_cache) + + hidden_states = layer_outputs[0] + past_key_value_cache.extend(layer_outputs[1]) + + hidden_states = self.layer_norm(hidden_states) + return hidden_states, past_key_value_cache + + +class MindieWhisperEncoderLayer(WhisperEncoderLayer): + def __init__(self, config): + super().__init__(config) + self.self_attn = MindiePFA( + embed_dim=self.embed_dim, + num_heads=config.encoder_attention_heads, + dropout=config.attention_dropout, + config=config, + ) + + +class MindieWhisperEncoder(WhisperEncoder): + + def __init__(self, config): + super().__init__(config) + self.layers = nn.ModuleList([MindieWhisperEncoderLayer(config) for _ in range(config.encoder_layers)]) + + +class MindieWhisperModel(WhisperModel): + + def __init__(self, config): + super().__init__(config) + self.decoder = MindieWhisperDecoder(config) + self.encoder = MindieWhisperEncoder(config) + + def forward( + self, + encoder_outputs, + decoder_input_ids, + past_key_values, + actual_seq_len, + use_cache: Optional[bool] = False, + return_dict: Optional[bool] = True, + input_features: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None + ) -> List[torch.Tensor]: + if input_features is None and encoder_outputs is None: + raise ValueError("Make sure to provide either `input_features` or `encoder_outputs` to `forward`.") + + if encoder_outputs is None: + input_features = self._mask_input_features(input_features, attention_mask=attention_mask) + + encoder_outputs = self.encoder( + input_features, + head_mask=None, + output_attentions=None, + output_hidden_states=None, + return_dict=return_dict, + ) + + decoder_outputs = self.decoder( + input_ids=decoder_input_ids, + encoder_hidden_states=encoder_outputs, + past_key_values=past_key_values, + use_cache=use_cache, + actual_seq_len=actual_seq_len + ) + return decoder_outputs + + +class MindieWhisperForConditionalGeneration(WhisperForConditionalGeneration, GenerationMixin): + + def __init__(self, config, enable_incre_flash_attention): + super().__init__(config) + config.is_use_ifa = enable_incre_flash_attention + self.model = MindieWhisperModel(config) + self.has_load = False + self.has_compile = False + self.mindie_encoder = None + self.mindie_decoder_prefill = None + self.mindie_decoder = None + self.self_attn_scatter = None + self.encoder_attn_scatter = None + self.save_path = None + self.encoder_seq_len = 1500 + self.file_prefix_names = ["mindie_whisper_encoder_bs", + "mindie_decoder_prefill_bs", + "mindie_whisper_decoder_bs", + "mindie_self_scatter_bs", + "mindie_encoder_scatter_bs" + ] + self.hidden_size = 1280 + 
self.head_num = 20
+        self.head_size = 64
+        self.seq_len = 1
+        self.layer_nums = 32
+        self.max_len = 448
+        # whisper-large-v3 geometry: d_model 1280 = 20 heads x 64 head_dim; 448 is the
+        # decoder's maximum generation length used to size the self-attn caches
+        self.past_key_value = []
+
+    def forward(self, *args):
+        if len(args) not in (2, 131):
+            raise ValueError(f"The args length of forward can only be 2 or 131, but got {len(args)}")
+        decoder_input_ids = args[0]
+        encoder_outputs = args[1]
+        actual_seq_len = None
+        past_key_values = None
+        if len(args) == 131:
+            actual_seq_len = args[2]
+            past_key_values = args[3:]
+        outputs = self.model(
+            decoder_input_ids=decoder_input_ids,
+            encoder_outputs=encoder_outputs,
+            past_key_values=past_key_values,
+            actual_seq_len=actual_seq_len,
+            use_cache=True,
+            return_dict=False,
+            input_features=None
+        )
+        lm_logits = self.proj_out(outputs[0])
+        return [lm_logits] + outputs[1]
+
+    def compile_encoder(self, args):
+        encoder = self.get_encoder()
+
+        class Encoder(torch.nn.Module):
+
+            def __init__(self, model):
+                super().__init__()
+                self.model = model
+
+            def forward(self, input_features):
+                return self.model(input_features=input_features, return_dict=False)
+
+        mel_feature_size = 128
+        max_frames = 3000
+        input_features = torch.randn([args.bs, mel_feature_size, max_frames])
+        encoder_traced = torch.jit.trace(Encoder(encoder), (input_features))
+        input_info = [mindietorch.Input(shape=(args.bs, mel_feature_size, max_frames))]
+        print("Start AOE optimization_level 1.")
+        # the first compile runs AOE tuning; its result is discarded on purpose
+        compiled = mindietorch.compile(encoder_traced,
+                                       inputs=input_info,
+                                       precision_policy=mindietorch.PrecisionPolicy.FP16,
+                                       soc_version=args.soc_version,
+                                       optimization_level=1
+                                       )
+        print("AOE optimize finished and start compile encoder.")
+        compiled = mindietorch.compile(encoder_traced,
+                                       inputs=input_info,
+                                       precision_policy=mindietorch.PrecisionPolicy.FP16,
+                                       soc_version=args.soc_version,
+                                       )
+        save_file = os.path.join(args.save_path, f"{self.file_prefix_names[0]}{args.bs}.ts")
+        torch.jit.save(compiled, save_file)
+
+        traced_file = os.path.join(args.save_path, f"{self.file_prefix_names[0]}{args.bs}.pt")
+        torch.jit.save(encoder_traced, traced_file)
+        print(f"Compile encoder success, saved in {save_file}")
+
+    def compile_prefill_decoder(self, args):
+        print("Start compiling prefill_decoder.")
+        encoder_seq_len = 1500
+        hidden_size = 1280
+        encoder_outputs = torch.randn([args.bs, encoder_seq_len, hidden_size])
+        decoder_input_ids = torch.randint(1, 4, (args.bs, 1))
+        prefill_decoder_traced = torch.jit.trace(self.eval(), (decoder_input_ids, encoder_outputs))
+        input_info = [mindietorch.Input(shape=(args.bs, 1), dtype=dtype.INT64),
+                      mindietorch.Input(shape=(args.bs, encoder_seq_len, hidden_size))]
+        prefill_decoder_compiled = mindietorch.compile(prefill_decoder_traced,
+                                                       inputs=input_info,
+                                                       precision_policy=mindietorch.PrecisionPolicy.FP16,
+                                                       soc_version=args.soc_version)
+
+        traced_file = os.path.join(args.save_path, f"{self.file_prefix_names[1]}{args.bs}.pt")
+        torch.jit.save(prefill_decoder_traced, traced_file)
+
+        save_file = os.path.join(args.save_path, f"{self.file_prefix_names[1]}{args.bs}.ts")
+        torch.jit.save(prefill_decoder_compiled, save_file)
+        print(f"Compile prefill_decoder success, saved in {save_file}.")
+
+    def compile_decoder(self, args):
+
+        class Decoder(torch.nn.Module):
+
+            def __init__(self, model):
+                super().__init__()
+                self.model = model
+
+            def forward(self, *args):
+                # caches are updated in place by IFA, so only the logits are returned
+                return self.model(*args)[0]
+
+        mindie_whisper = MindieWhisperForConditionalGeneration.from_pretrained(args.model_path,
+                                                                               enable_incre_flash_attention=True)
+        decoder = Decoder(mindie_whisper)
+        print("Start compiling decoder.")
+
+        encoder_outputs = torch.randn([args.bs, self.encoder_seq_len, self.hidden_size])
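+        # Decode-step input layout (matches forward above): [decoder_input_ids,
+        # encoder_outputs, actual_seq_len] plus 32 layers x 4 cache tensors
+        # (self-attn K/V, cross-attn K/V) = 131 inputs in total.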
hidden_size]) + decoder_input_ids = torch.randint(1, 4, (args.bs, 1)) + actual_seq_len = torch.ones((args.bs, 1)) + all_past_key_value = [torch.randn([args.bs, self.seq_len, self.head_num, self.head_size]), + torch.randn([args.bs, self.seq_len, self.head_num, self.head_size]), + torch.randn([args.bs, self.encoder_seq_len, self.head_num, self.head_size]), + torch.randn([args.bs, self.encoder_seq_len, self.head_num, self.head_size])] * layer_nums + traced_args = [decoder_input_ids, encoder_outputs, actual_seq_len] + traced_args.extend(all_past_key_value) + traced_decoder = torch.jit.trace(decoder, traced_args) + # BSND + key_value_infos = [ + mindietorch.Input(shape=(args.bs, max_len, head_num, head_size), + dtype=dtype.FLOAT16), + mindietorch.Input(shape=(args.bs, max_len, head_num, head_size), + dtype=dtype.FLOAT16), + mindietorch.Input(shape=(args.bs, head_num, self.encoder_seq_len, head_size), + dtype=dtype.FLOAT16), + mindietorch.Input(shape=(args.bs, head_num, self.encoder_seq_len, head_size), + dtype=dtype.FLOAT16 + )] * layer_nums + input_info = [mindietorch.Input(shape=(args.bs, 1), dtype=dtype.INT64), + mindietorch.Input(shape=(args.bs, self.encoder_seq_len, hidden_size)), + mindietorch.Input(shape=(args.bs, 1), dtype=dtype.INT64)] + + input_info.extend(key_value_infos) + float_size = 4 + voc_size = 51866 + buffer_size = math.ceil((args.bs * 1 * voc_size * float_size) / 1024 / 1024) + print(f"Set {buffer_size}/MB for output.") + compiled_decoder = mindietorch.compile(traced_decoder, + inputs=input_info, + precision_policy=mindietorch.PrecisionPolicy.FP16, + soc_version=args.soc_version, + default_buffer_size_vec=[buffer_size]) + + traced_file = os.path.join(save_path, f"{self.file_prefix_names[2]}{args.bs}.pt") + torch.jit.save(traced_decoder, traced_file) + + save_file = os.path.join(save_path, f"{self.file_prefix_names[2]}{args.bs}.ts") + torch.jit.save(compiled_decoder, save_file) + + print(f"Compile whisper_decoder success, saved in {save_file}.") + + def compile_scatter_update(self, args): + + class MindieScatter(torch.nn.Module): + def forward(self, past_key_value, indices, update_states): + out = torch.ops.aie.scatter_update(self, past_key_value, indices, update_states, 1) + return out + bs = args.bs + self_past_key_value = torch.randn([bs, self.max_len, self.head_num, self.head_size]) + encoder_past_key_value = torch.rand([bs, self.max_len, self.head_num, self.head_size]) + indices = torch.tensor([0]*bs) + update_states = torch.randn([bs, 1, 20, 64]) + traced = torch.jit.trace(MindieScatter(), (self_past_key_value, indices, update_states)) + + self_attn_info = mindietorch.Input(shape=self_past_key_value.shape, dtype=mindietorch.dtype.FLOAT16) + encoder_attn_info = mindietorch.Input(shape=encoder_past_key_value.shape, dtype=mindietorch.dtype.FLOAT16) + indices_info = mindietorch.Input(shape=indices.shape, dtype=mindietorch.dtype.INT64) + update_states_info = mindietorch.Input(shape=encoder_past_key_value.shape, dtype=mindietorch.dtype.FLOAT16) + + compile_self = mindietorch.compile(traced, inputs=[self_attn_info, indices_info, update_states_info], + soc_version=args.soc_version) + torch.jit.save(compile_self, f"{self.file_prefix_names[3]}{args.bs}.ts") + + compile_self = mindietorch.compile(traced, inputs=[encoder_attn_info, indices_info, update_states_info], + soc_version=args.soc_version) + torch.jit.save(compile_self, f"{self.file_prefix_names[4]}{args.bs}.ts") + print("compile scatter success.") + + def compile(self, args): + print("Start compiling Mindie-Whisper, it 
will take some time, please wait.")
+        save_path = args.save_path
+        model_path = args.model_path
+
+        if not save_path:
+            raise ValueError("Please provide the directory where the compiled model saved.")
+        if not os.path.exists(save_path):
+            os.makedirs(save_path)
+            print(f"Directory {save_path} created.")
+        else:
+            print(f"Directory {save_path} already exists.")
+        self.compile_encoder(args)
+        self.compile_prefill_decoder(args)
+        self.compile_decoder(args)
+        self.compile_scatter_update(args)
+        self.has_compile = True
+
+    def _init_past_key_value_cache(self, bsz):
+        for _ in range(32):
+            self.past_key_value.append(torch.ones([bsz, self.max_len, self.head_num, self.head_size]))
+            self.past_key_value.append(torch.ones([bsz, self.max_len, self.head_num, self.head_size]))
+            self.past_key_value.append(torch.ones([bsz, self.encoder_seq_len, self.head_num, self.head_size]))
+            self.past_key_value.append(torch.ones([bsz, self.encoder_seq_len, self.head_num, self.head_size]))
+        print("init past key value cache success.")
+
+    def load_mindie_models(self, save_path, batch_size):
+        if not (save_path and batch_size):
+            raise ValueError(f"Please provide batch_size and the directory where the compiled models are saved, "
+                             f"but found save_path is {save_path}, batch_size is {batch_size}.")
+        self._check_save_path(save_path, batch_size)
+        self.batch_size = batch_size
+        self._init_past_key_value_cache(batch_size)
+
+        if not self.has_load:
+            self.mindie_encoder = torch.jit.load(f"{save_path}/{self.file_prefix_names[0]}{batch_size}.ts")
+            print(f"load {self.file_prefix_names[0]}{batch_size}.ts success.")
+
+            self.mindie_decoder_prefill = torch.jit.load(f"{save_path}/{self.file_prefix_names[1]}{batch_size}.ts")
+            print(f"load {self.file_prefix_names[1]}{batch_size}.ts success.")
+
+            self.mindie_decoder = torch.jit.load(f"{save_path}/{self.file_prefix_names[2]}{batch_size}.ts")
+            print(f"load {self.file_prefix_names[2]}{batch_size}.ts success.")
+
+            self.self_attn_scatter = torch.jit.load(f"{save_path}/{self.file_prefix_names[3]}{batch_size}.ts")
+            print(f"load {self.file_prefix_names[3]}{batch_size}.ts success.")
+
+            self.encoder_attn_scatter = torch.jit.load(f"{save_path}/{self.file_prefix_names[4]}{batch_size}.ts")
+            print(f"load {self.file_prefix_names[4]}{batch_size}.ts success.")
+            self.has_load = True
+        else:
+            print("Mindie whisper has already been loaded.")
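
The flat cache built by `_init_past_key_value_cache` above is what `greedy_search` below scatters into, four tensors per decoder layer. A minimal plain-PyTorch sketch of that layout and of a per-step write; `scatter_into_cache` is an illustrative stand-in for the compiled scatter models, and the shapes are the constants from `__init__`:

```python
import torch

bs, layers = 8, 32
head_num, head_size, max_len, encoder_seq_len = 20, 64, 448, 1500

# Flat list with 4 entries per layer: [self K, self V, cross K, cross V] * 32.
cache = []
for _ in range(layers):
    cache.append(torch.zeros(bs, max_len, head_num, head_size))          # self-attn K
    cache.append(torch.zeros(bs, max_len, head_num, head_size))          # self-attn V
    cache.append(torch.zeros(bs, encoder_seq_len, head_num, head_size))  # cross-attn K
    cache.append(torch.zeros(bs, encoder_seq_len, head_num, head_size))  # cross-attn V

def scatter_into_cache(dst, indices, update):
    # Write `update` at sequence offset `indices[b]` of each batch element,
    # mirroring what the compiled scatter models do along axis 1.
    for b in range(dst.shape[0]):
        start = int(indices[b])
        dst[b, start:start + update.shape[1]] = update[b]

# Layer idx's prefill outputs sit at outputs[1 + 4*idx .. 1 + 4*idx + 3] and
# belong in cache[4*idx .. 4*idx + 3], which is how greedy_search indexes below.
indices = torch.zeros(bs, dtype=torch.long)
scatter_into_cache(cache[0], indices, torch.randn(bs, 1, head_num, head_size))
```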
+
+    def greedy_search(
+            self,
+            input_ids: torch.LongTensor,
+            logits_processor: Optional[LogitsProcessorList] = None,
+            stopping_criteria: Optional[StoppingCriteriaList] = None,
+            max_length: Optional[int] = None,
+            pad_token_id: Optional[int] = None,
+            eos_token_id: Optional[Union[int, List[int]]] = None,
+            output_scores: Optional[bool] = None,
+            return_dict_in_generate: Optional[bool] = None,
+            synced_gpus: bool = False,
+            **model_kwargs):
+        logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
+        stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
+        if max_length is not None:
+            print(
+                "`max_length` is deprecated in this function, use"
+                " `stopping_criteria=StoppingCriteriaList([MaxLengthCriteria(max_length=max_length)])` instead."
+            )
+            stopping_criteria = validate_stopping_criteria(stopping_criteria, max_length)
+        pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id
+        eos_token_id = eos_token_id if eos_token_id is not None else self.generation_config.eos_token_id
+        if isinstance(eos_token_id, int):
+            eos_token_id = [eos_token_id]
+        eos_token_id_tensor = torch.tensor(eos_token_id).to("cpu") if eos_token_id is not None else None
+        output_scores = output_scores if output_scores is not None else self.generation_config.output_scores
+        # init attention / hidden states / scores tuples
+        scores = () if (return_dict_in_generate and output_scores) else None
+
+        # keep track of which sequences are already finished
+        unfinished_sequences = torch.ones(input_ids.shape[0], dtype=torch.long, device="cpu")
+
+        this_peer_finished = False  # used by synced_gpus only
+        kv_actual_step = 1
+        indices = torch.tensor([0] * input_ids.shape[0])
+        is_first_step = True
+        while True:
+
+            model_inputs = self.prepare_inputs_for_generation(input_ids, is_first_step, **model_kwargs)
+            args = [model_inputs["decoder_input_ids"].contiguous().to("npu"), model_inputs["encoder_outputs"]]
+            if is_first_step:
+                outputs = self.mindie_decoder_prefill(*args)
+                for idx in range(32):
+                    self.self_attn_scatter(self.past_key_value[4 * idx], indices, outputs[1 + 4 * idx])
+                    self.self_attn_scatter(self.past_key_value[4 * idx + 1], indices, outputs[1 + 4 * idx + 1])
+                    self.encoder_attn_scatter(self.past_key_value[4 * idx + 2], indices, outputs[1 + 4 * idx + 2])
+                    self.encoder_attn_scatter(self.past_key_value[4 * idx + 3], indices, outputs[1 + 4 * idx + 3])
+                is_first_step = False
+            else:
+                kv_actual_step += 1
+                args.append(torch.tensor([kv_actual_step] * input_ids.shape[0]).to("npu"))
+                args.extend(self.past_key_value)
+                outputs = self.mindie_decoder(*args)
+            if isinstacnce(outputs, list):
+                next_token_logits = outputs[0].to("cpu")[:, -1, :]
+            else:
+                next_token_logits = outputs.to("cpu")[:, -1, :]
+
+            # pre-process distribution
+            next_tokens_scores = logits_processor(input_ids.to("cpu"), next_token_logits)
+
+            # argmax
+            next_tokens = torch.argmax(next_tokens_scores, dim=-1)
+
+            # finished sentences should have their next token be a padding token
+            if eos_token_id is not None:
+                if pad_token_id is None:
+                    raise ValueError("If `eos_token_id` is defined, make sure that `pad_token_id` is defined.")
+                next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences)
+
+            # update generated ids, model inputs, and length for next step
+            input_ids = torch.cat([input_ids.to("cpu"), next_tokens[:, None]], dim=-1)
+
+            model_kwargs = self._update_model_kwargs_for_generation(
+                outputs, model_kwargs
+            )
+
+            # if eos_token was found in one sentence, set sentence to finished
+            if eos_token_id_tensor is not None:
+                unfinished_sequences = unfinished_sequences.mul(
+                    next_tokens.tile(eos_token_id_tensor.shape[0], 1).ne(eos_token_id_tensor.unsqueeze(1)).prod(dim=0)
+                )
+
+            # stop when each sentence is finished
+            if unfinished_sequences.max() == 0:
+                this_peer_finished = True
+
+            # stop if we exceed the maximum length
+            if stopping_criteria(input_ids, scores):
+                this_peer_finished = True
+
+            if this_peer_finished and not synced_gpus:
+                break
+
+        return input_ids
+
+    def generate(
+            self,
+            input_features: Optional[torch.Tensor] = None,
+            generation_config=None,
+            logits_processor=None,
+            stopping_criteria=None,
+            prefix_allowed_tokens_fn=None,
+            synced_gpus=False,
+            return_dict_in_generate: Optional[bool] = None,
+            **kwargs,
+    ):
+        if generation_config is None:
+            generation_config = copy.deepcopy(self.generation_config)
+        num_segment_frames = 3000
+        assert input_features.shape[-1] == num_segment_frames, "Only 30s speech is supported."
+        encoder_outputs = self.mindie_encoder(input_features.to("npu"))[0]
+        kwargs["encoder_outputs"] = encoder_outputs
+        outputs = GenerationMixin.generate(
+            self,
+            input_features,
+            generation_config,
+            logits_processor,
+            stopping_criteria,
+            prefix_allowed_tokens_fn,
+            synced_gpus,
+            return_dict_in_generate=return_dict_in_generate,
+            **kwargs
+        )
+        return outputs
+
+    def _update_model_kwargs_for_generation(
+            self,
+            outputs,
+            model_inputs,
+    ):
+        return model_inputs
+
+    def _maybe_initialize_input_ids_for_generation(
+            self,
+            inputs: Optional[torch.Tensor] = None,
+            bos_token_id: Optional[int] = None,
+            model_kwargs: Optional[Dict[str, torch.Tensor]] = None,
+    ) -> torch.LongTensor:
+        """Initializes input ids for generation, if necessary."""
+        if inputs is not None:
+            return inputs
+
+        encoder_outputs = model_kwargs.get("encoder_outputs")
+        if self.config.is_encoder_decoder and encoder_outputs is not None:
+            # make dummy input_ids with value -100, as a sanity check ensuring that they won't be used for encoding
+            shape = encoder_outputs.size()[:-1]
+            input_ids = torch.ones(shape, dtype=torch.long, device="cpu") * -100
+            return input_ids.to("npu")
+
+        if bos_token_id is None:
+            raise ValueError("`bos_token_id` has to be defined when no `input_ids` are provided.")
+
+        # If there is some tensor in `model_kwargs`, we can infer the batch size from it. This is helpful with
+        # soft-prompting or in multimodal implementations built on top of decoder-only language models.
+        batch_size = 1
+        for value in model_kwargs.values():
+            if isinstance(value, torch.Tensor):
+                batch_size = value.shape[0]
+                break
+        return torch.ones((batch_size, 1), dtype=torch.long, device=self.device) * bos_token_id
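
`generate` above hard-asserts 3000 mel frames, i.e. exactly 30 s of audio after padding. A sketch of producing matching `input_features` with the Hugging Face processor; the model directory and audio file are placeholders:

```python
import librosa
from transformers import WhisperProcessor

processor = WhisperProcessor.from_pretrained("./model_path")  # placeholder model dir
audio, _ = librosa.load("sample.wav", sr=16000)               # placeholder audio file

# The whisper-large-v3 feature extractor pads/truncates to 30 s, which yields
# a (1, 128, 3000) tensor: 128 mel bins by 3000 frames.
input_features = processor(audio, sampling_rate=16000, return_tensors="pt").input_features
assert input_features.shape[-1] == 3000
```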
+
+    def _prepare_decoder_input_ids_for_generation(
+            self,
+            batch_size: int,
+            model_input_name: str,
+            model_kwargs: Dict[str, torch.Tensor],
+            decoder_start_token_id: int = None,
+            bos_token_id: int = None,
+            device: torch.device = None,
+    ) -> Tuple[torch.LongTensor, Dict[str, torch.Tensor]]:
+        """Prepares `decoder_input_ids` for generation with encoder-decoder models"""
+        # 1. Check whether the user has defined `decoder_input_ids` manually. To facilitate in terms of input naming,
+        # we also allow the user to pass it under `input_ids`, if the encoder does not use it as the main input.
+        if model_kwargs is not None and "decoder_input_ids" in model_kwargs:
+            decoder_input_ids = model_kwargs.pop("decoder_input_ids")
+        elif "input_ids" in model_kwargs and model_input_name != "input_ids":
+            decoder_input_ids = model_kwargs.pop("input_ids")
+        else:
+            decoder_input_ids = None
+
+        # 2. Encoder-decoder models expect the `decoder_input_ids` to start with a special token. Let's ensure that.
+        decoder_start_token_id = self._get_decoder_start_token_id(decoder_start_token_id, bos_token_id)
+        if device is None:
+            device = self.device
+        decoder_input_ids_start = torch.ones((batch_size, 1), dtype=torch.long, device="cpu") * decoder_start_token_id
+
+        # no user input -> use decoder_start_token_id as decoder_input_ids
+        if decoder_input_ids is None:
+            decoder_input_ids = decoder_input_ids_start
+        # exception: Donut checkpoints have task-specific decoder starts and don't expect a BOS token
+        elif self.config.model_type == "vision-encoder-decoder" and "donut" in self.name_or_path.lower():
+            pass
+        # user input but doesn't start with decoder_start_token_id -> prepend decoder_start_token_id (and adjust
+        # decoder_attention_mask if provided)
+        elif (decoder_input_ids[:, 0] != decoder_start_token_id).all().item():
+            decoder_input_ids = torch.cat([decoder_input_ids_start, decoder_input_ids], dim=-1)
+            if "decoder_attention_mask" in model_kwargs:
+                decoder_attention_mask = model_kwargs["decoder_attention_mask"]
+                decoder_attention_mask = torch.cat(
+                    (torch.ones_like(decoder_attention_mask)[:, :1], decoder_attention_mask),
+                    dim=-1,
+                )
+                model_kwargs["decoder_attention_mask"] = decoder_attention_mask.to(self.device)
+
+        return decoder_input_ids.to("npu"), model_kwargs
+
+    def _check_save_path(self, save_path, batch_size):
+        file_list = os.listdir(save_path)
+        expected_files = [file + f"{batch_size}.ts" for file in self.file_prefix_names]
+        for file in expected_files:
+            if file not in file_list:
+                raise ValueError(f"Expected file name is {file}, but can't be found in path: {save_path}")
+
+    def prepare_inputs_for_generation(
+            self,
+            decoder_input_ids,
+            if_first_step,
+            encoder_outputs=None,
+            **kwargs
+    ):
+        if not if_first_step:
+            decoder_input_ids_shape = decoder_input_ids.shape
+            remove_prefix_length = decoder_input_ids_shape[1] - 1
+            decoder_input_ids = decoder_input_ids[:, remove_prefix_length:]
+        return {
+            "encoder_outputs": encoder_outputs,
+            "decoder_input_ids": decoder_input_ids
+        }
-- Gitee

From faa0959f5a1560b4fb1e42ad9e3f3a87dcb19d9d Mon Sep 17 00:00:00 2001
From: tiansongxue
Date: Tue, 29 Oct 2024 09:43:51 +0800
Subject: [PATCH 28/44] add whisperx/modeling_whisper

---
 .../audio/whisperX/compile_whisper.py         |  7 -------
 .../audio/whisperX/modeling_whisper.py        | 21 ++++++++++---------
 2 files changed, 11 insertions(+), 17 deletions(-)

diff --git a/MindIE/MindIE-Torch/built-in/audio/whisperX/compile_whisper.py b/MindIE/MindIE-Torch/built-in/audio/whisperX/compile_whisper.py
index 88cbf73589..2fb937b8ab 100644
--- a/MindIE/MindIE-Torch/built-in/audio/whisperX/compile_whisper.py
+++ b/MindIE/MindIE-Torch/built-in/audio/whisperX/compile_whisper.py
@@ -1,17 +1,11 @@
 import argparse
 import time
-from transformers import WhisperProcessor
-from nltk.metrics.distance import edit_distance
-
 import mindietorch
-from modeling_whisper import MindieWhisperForConditionalGeneration
 
 
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
     parser.add_argument('-model_path', type=str, required=True, help="please provide model path.")
-    parser.add_argument('-audio_dir', type=str, required=True, help="please provide speech samples dir.")
-    parser.add_argument('-trans_file', type=str, required=True, help="please provide label file.")
     parser.add_argument('-bs', type=int, default=8, help="please provide batch_size, default:8.")
     parser.add_argument('-soc_version', type=str, choices=["Ascend310P3", "Ascend910B4"], default="Ascend310P3",
                         help="please provide soc_version, default:Ascend310P3.")
@@
-20,5 +14,4 @@ if __name__ == "__main__": args = parser.parse_args() device = f"npu:{args.device_id}" - mindie_whisper = MindieWhisperForConditionalGeneration.from_pretrained(args.model_path) mindie_whisper.compile(args) \ No newline at end of file diff --git a/MindIE/MindIE-Torch/built-in/audio/whisperX/modeling_whisper.py b/MindIE/MindIE-Torch/built-in/audio/whisperX/modeling_whisper.py index 249fc1f9f5..78b0bba1ff 100644 --- a/MindIE/MindIE-Torch/built-in/audio/whisperX/modeling_whisper.py +++ b/MindIE/MindIE-Torch/built-in/audio/whisperX/modeling_whisper.py @@ -19,6 +19,7 @@ from transformers.models.whisper.modeling_whisper import WhisperDecoder, Whisper from transformers.generation.logits_process import LogitsProcessorList from transformers.generation.stopping_criteria import StoppingCriteriaList, validate_stopping_criteria from transformers.generation.utils import GenerationMixin +from modeling_whisper import MindieWhisperForConditionalGeneration import mindietorch from mindietorch._enums import dtype @@ -30,7 +31,7 @@ class MindiePFA(WhisperAttention): self, embed_dim: int, num_heads: int, - dropout: float = 0,0, + dropout: float = 0.0, is_decoder: bool = False, bias: bool = True, is_causal: bool = False, @@ -443,7 +444,7 @@ class MindieWhisperModel(WhisperModel): class MindieWhisperForConditionalGeneration(WhisperForConditionalGeneration, GenerationMixin): - def __init__(self, config, enable_incre_flash_attention): + def __init__(self, config, enable_incre_flash_attention=False): super().__init__(config) config.is_use_ifa = enable_incre_flash_attention self.model = MindieWhisperModel(config) @@ -494,8 +495,8 @@ class MindieWhisperForConditionalGeneration(WhisperForConditionalGeneration, Gen lm_logits = self.proj_out(outputs[0]) return [lm_logits] + outputs[1] - def compile_encoder(self, args): - encoder = self.get_encoder() + def compile_encoder(self, args, model): + encoder = model.get_encoder() class Encoder(torch.nn.Module): def __init__(self, model): @@ -530,13 +531,13 @@ class MindieWhisperForConditionalGeneration(WhisperForConditionalGeneration, Gen torch.jit.save(encoder_traced, traced_file) print(f"Compile encoder success, saved in {save_file}") - def compile_prefill_decoder(self, args): + def compile_prefill_decoder(self, args, model): print("Start compiling prefill_decoder.") encoder_seq_len = 1500 hidden_size = 1280 encoder_outputs = torch.randn([args.bs, encoder_seq_len, hidden_size]) decoder_input_ids = torch.randint(1, 4, (args.bs, 1)) - prefill_decoder_traced = torch.jit.trace(self.eval(), (decoder_input_ids, encoder_outputs)) + prefill_decoder_traced = torch.jit.trace(model.eval(), (decoder_input_ids, encoder_outputs)) input_info = [mindietorch.Input(shape=(args.bs, 1), dtype=dtype.INT64), mindietorch.Input(shape=(args.bs, encoder_seq_len, hidden_size))] prefill_decoder_compiled = mindietorch.compile(prefill_decoder_traced, @@ -564,7 +565,7 @@ class MindieWhisperForConditionalGeneration(WhisperForConditionalGeneration, Gen mindie_whisper = MindieWhisperForConditionalGeneration.from_pretrained(args.model_path, enable_incre_flash_attention=True) - decoder = Decoder(model) + decoder = Decoder(mindie_whisper) print("Start compiling decoder.") @@ -642,7 +643,7 @@ class MindieWhisperForConditionalGeneration(WhisperForConditionalGeneration, Gen def compile(self, args): print("Start compiling Mindie-Whisper, it will take some time, please wait.") save_path = args.save_path - model_path = args.model_path + whisper_model = 
MindieWhisperForConditionalGeneration.from_pretrained(args.model_path) if not save_path: raise ValueError("Please provide the directory where the compiled model saved.") @@ -651,8 +652,8 @@ class MindieWhisperForConditionalGeneration(WhisperForConditionalGeneration, Gen print(f"Directory {save_path} created.") else: print(f"Directory {save_path} already exists.") - self.compile_encoder(args) - self.compile_prefill_decoder(args) + self.compile_encoder(args, whisper_model) + self.compile_prefill_decoder(args, whisper_model) self.compile_decoder(args) self.compile_scatter_update(args) self.has_compile = True -- Gitee From 2371d8a5574d60e6d8f189f407479d6b86990a44 Mon Sep 17 00:00:00 2001 From: tiansongxue Date: Tue, 29 Oct 2024 09:47:47 +0800 Subject: [PATCH 29/44] add whisperx/modeling_whisper22222 --- .../built-in/audio/whisperX/compile_whisper.py | 1 + .../built-in/audio/whisperX/modeling_whisper.py | 11 +++++------ 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/MindIE/MindIE-Torch/built-in/audio/whisperX/compile_whisper.py b/MindIE/MindIE-Torch/built-in/audio/whisperX/compile_whisper.py index 2fb937b8ab..194383f0ac 100644 --- a/MindIE/MindIE-Torch/built-in/audio/whisperX/compile_whisper.py +++ b/MindIE/MindIE-Torch/built-in/audio/whisperX/compile_whisper.py @@ -13,5 +13,6 @@ if __name__ == "__main__": parser.add_argument('-device_id', type=int, default=0) args = parser.parse_args() + mindie_whisper = MindieWhisperForConditionalGeneration.from_pretrained(args.model_path) device = f"npu:{args.device_id}" mindie_whisper.compile(args) \ No newline at end of file diff --git a/MindIE/MindIE-Torch/built-in/audio/whisperX/modeling_whisper.py b/MindIE/MindIE-Torch/built-in/audio/whisperX/modeling_whisper.py index 78b0bba1ff..48fa9e1ce2 100644 --- a/MindIE/MindIE-Torch/built-in/audio/whisperX/modeling_whisper.py +++ b/MindIE/MindIE-Torch/built-in/audio/whisperX/modeling_whisper.py @@ -496,7 +496,7 @@ class MindieWhisperForConditionalGeneration(WhisperForConditionalGeneration, Gen return [lm_logits] + outputs[1] def compile_encoder(self, args, model): - encoder = model.get_encoder() + encoder = self.get_encoder() class Encoder(torch.nn.Module): def __init__(self, model): @@ -531,13 +531,13 @@ class MindieWhisperForConditionalGeneration(WhisperForConditionalGeneration, Gen torch.jit.save(encoder_traced, traced_file) print(f"Compile encoder success, saved in {save_file}") - def compile_prefill_decoder(self, args, model): + def compile_prefill_decoder(self, args): print("Start compiling prefill_decoder.") encoder_seq_len = 1500 hidden_size = 1280 encoder_outputs = torch.randn([args.bs, encoder_seq_len, hidden_size]) decoder_input_ids = torch.randint(1, 4, (args.bs, 1)) - prefill_decoder_traced = torch.jit.trace(model.eval(), (decoder_input_ids, encoder_outputs)) + prefill_decoder_traced = torch.jit.trace(self.eval(), (decoder_input_ids, encoder_outputs)) input_info = [mindietorch.Input(shape=(args.bs, 1), dtype=dtype.INT64), mindietorch.Input(shape=(args.bs, encoder_seq_len, hidden_size))] prefill_decoder_compiled = mindietorch.compile(prefill_decoder_traced, @@ -643,7 +643,6 @@ class MindieWhisperForConditionalGeneration(WhisperForConditionalGeneration, Gen def compile(self, args): print("Start compiling Mindie-Whisper, it will take some time, please wait.") save_path = args.save_path - whisper_model = MindieWhisperForConditionalGeneration.from_pretrained(args.model_path) if not save_path: raise ValueError("Please provide the directory where the compiled model saved.") @@ -652,8 +651,8 
@@ class MindieWhisperForConditionalGeneration(WhisperForConditionalGeneration, Gen print(f"Directory {save_path} created.") else: print(f"Directory {save_path} already exists.") - self.compile_encoder(args, whisper_model) - self.compile_prefill_decoder(args, whisper_model) + self.compile_encoder(args) + self.compile_prefill_decoder(args) self.compile_decoder(args) self.compile_scatter_update(args) self.has_compile = True -- Gitee From 2dc7b8729de17fc494cb6b6986fb56c7730b8cf8 Mon Sep 17 00:00:00 2001 From: tiansongxue Date: Tue, 29 Oct 2024 15:19:35 +0800 Subject: [PATCH 30/44] add whisperx/modeling_whisper22222 --- .../audio/whisperX/compile_whisper.py | 1 + .../audio/whisperX/modeling_whisper.py | 25 +++++++++---------- 2 files changed, 13 insertions(+), 13 deletions(-) diff --git a/MindIE/MindIE-Torch/built-in/audio/whisperX/compile_whisper.py b/MindIE/MindIE-Torch/built-in/audio/whisperX/compile_whisper.py index 194383f0ac..217b42d96e 100644 --- a/MindIE/MindIE-Torch/built-in/audio/whisperX/compile_whisper.py +++ b/MindIE/MindIE-Torch/built-in/audio/whisperX/compile_whisper.py @@ -1,6 +1,7 @@ import argparse import time import mindietorch +from modeling_whisper import MindieWhisperForConditionalGeneration if __name__ == "__main__": diff --git a/MindIE/MindIE-Torch/built-in/audio/whisperX/modeling_whisper.py b/MindIE/MindIE-Torch/built-in/audio/whisperX/modeling_whisper.py index 48fa9e1ce2..2a0efdc42d 100644 --- a/MindIE/MindIE-Torch/built-in/audio/whisperX/modeling_whisper.py +++ b/MindIE/MindIE-Torch/built-in/audio/whisperX/modeling_whisper.py @@ -19,7 +19,6 @@ from transformers.models.whisper.modeling_whisper import WhisperDecoder, Whisper from transformers.generation.logits_process import LogitsProcessorList from transformers.generation.stopping_criteria import StoppingCriteriaList, validate_stopping_criteria from transformers.generation.utils import GenerationMixin -from modeling_whisper import MindieWhisperForConditionalGeneration import mindietorch from mindietorch._enums import dtype @@ -48,7 +47,7 @@ class MindiePFA(WhisperAttention): # BSND def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): - return torch.view(bsz, seq_len, self.num_heads, self.head_dim).contiguous() + return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).contiguous() def forward( self, @@ -129,7 +128,7 @@ class MindieIFA(WhisperAttention): # BSND def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): - return torch.view(bsz, seq_len, self.num_heads, self.head_dim).contiguous() + return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).contiguous() def forward( self, @@ -251,7 +250,7 @@ class MindieWhisperDecoderLayer(WhisperDecoderLayer): config=config, is_decoder=True ) - slef.encoder_attn = MindiePFA( + self.encoder_attn = MindiePFA( embed_dim=self.embed_dim, num_heads=config.decoder_attention_heads, is_decoder=True, @@ -495,7 +494,7 @@ class MindieWhisperForConditionalGeneration(WhisperForConditionalGeneration, Gen lm_logits = self.proj_out(outputs[0]) return [lm_logits] + outputs[1] - def compile_encoder(self, args, model): + def compile_encoder(self, args): encoder = self.get_encoder() class Encoder(torch.nn.Module): @@ -524,10 +523,10 @@ class MindieWhisperForConditionalGeneration(WhisperForConditionalGeneration, Gen precision_policy=mindietorch.PrecisionPolicy.FP16, soc_version=args.soc_version, ) - save_file = os.path.join(save_path, f"{self.file_prefix_names[0]}{args.bs}.ts") + save_file = os.path.join(args.save_path, f"{self.file_prefix_names[0]}{args.bs}.ts") 
torch.jit.save(compiled, save_file) - traced_file = os.path.join(save_path, f"{self.file_prefix_names[0]}{args.bs}.pt") + traced_file = os.path.join(args.save_path, f"{self.file_prefix_names[0]}{args.bs}.pt") torch.jit.save(encoder_traced, traced_file) print(f"Compile encoder success, saved in {save_file}") @@ -545,10 +544,10 @@ class MindieWhisperForConditionalGeneration(WhisperForConditionalGeneration, Gen precision_policy=mindietorch.PrecisionPolicy.FP16, soc_version=args.soc_version) - traced_file = os.path.join(save_path, f"{self.file_prefix_names[1]}{args.bs}.pt") + traced_file = os.path.join(args.save_path, f"{self.file_prefix_names[1]}{args.bs}.pt") torch.jit.save(prefill_decoder_traced, traced_file) - save_file = os.path.join(save_path, f"{self.file_prefix_names[1]}{args.bs}.ts") + save_file = os.path.join(args.save_path, f"{self.file_prefix_names[1]}{args.bs}.ts") torch.jit.save(prefill_decoder_compiled, save_file) print(f"Compile prefill_decoder success, saved in {save_file}.") @@ -605,10 +604,10 @@ class MindieWhisperForConditionalGeneration(WhisperForConditionalGeneration, Gen soc_version=args.soc_version, default_buffer_size_vec=[buffer_size]) - traced_file = os.path.join(save_path, f"{self.file_prefix_names[2]}{args.bs}.pt") + traced_file = os.path.join(args.save_path, f"{self.file_prefix_names[2]}{args.bs}.pt") torch.jit.save(traced_decoder, traced_file) - save_file = os.path.join(save_path, f"{self.file_prefix_names[2]}{args.bs}.ts") + save_file = os.path.join(args.save_path, f"{self.file_prefix_names[2]}{args.bs}.ts") torch.jit.save(compiled_decoder, save_file) print(f"Compile whisper_decoder success, saved in {save_file}.") @@ -633,11 +632,11 @@ class MindieWhisperForConditionalGeneration(WhisperForConditionalGeneration, Gen compile_self = mindietorch.compile(traced, inputs=[self_attn_info, indices_info, update_states_info], soc_version=args.soc_version) - torch.jit.save(compile_self, f"{self.file_prefix_names[3]}{args.bs}.ts") + torch.jit.save(compile_self, f"{args.save_path}/{self.file_prefix_names[3]}{args.bs}.ts") compile_self = mindietorch.compile(traced, inputs=[encoder_attn_info, indices_info, update_states_info], soc_version=args.soc_version) - torch.jit.save(compile_self, f"{self.file_prefix_names[4]}{args.bs}.ts") + torch.jit.save(compile_self, f"{args.save_path}/{self.file_prefix_names[4]}{args.bs}.ts") print("compile scatter success.") def compile(self, args): -- Gitee From 5bcec66e2f67792c484fdf126b297e07e3429aef Mon Sep 17 00:00:00 2001 From: tiansongxue Date: Tue, 29 Oct 2024 16:11:41 +0800 Subject: [PATCH 31/44] fix11111 --- .../built-in/audio/whisperX/modeling_whisper.py | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/MindIE/MindIE-Torch/built-in/audio/whisperX/modeling_whisper.py b/MindIE/MindIE-Torch/built-in/audio/whisperX/modeling_whisper.py index 2a0efdc42d..092ee80a39 100644 --- a/MindIE/MindIE-Torch/built-in/audio/whisperX/modeling_whisper.py +++ b/MindIE/MindIE-Torch/built-in/audio/whisperX/modeling_whisper.py @@ -410,8 +410,8 @@ class MindieWhisperModel(WhisperModel): self, encoder_outputs, decoder_input_ids, - past_key_values, - actual_seq_len, + past_key_values: Optional[torch.Tensor] = None, + actual_seq_len: Optional[torch.Tensor] = None, use_cache: Optional[bool] = False, return_dict: Optional[bool] = True, input_features: Optional[torch.FloatTensor] = None, @@ -470,10 +470,6 @@ class MindieWhisperForConditionalGeneration(WhisperForConditionalGeneration, Gen self.max_len = 448 self.past_key_value = [] - - - - 
def forward(self, *args): if len(args) not in (2, 131): raise ValueError(f"The args length of forward can only be 2 or 131, but got {len(args)}") @@ -482,6 +478,9 @@ class MindieWhisperForConditionalGeneration(WhisperForConditionalGeneration, Gen if len(args) == 131: actual_seq_len = args[2] past_key_values = args[3:] + else: + actual_seq_len = None + past_key_values = None outputs = self.model( decoder_input_ids=decoder_input_ids, encoder_outputs=encoder_outputs, -- Gitee From 078509db6fea91360c1fa0d8082a829ae6bd1a4c Mon Sep 17 00:00:00 2001 From: tiansongxue Date: Tue, 29 Oct 2024 16:34:09 +0800 Subject: [PATCH 32/44] fix22222 --- .../built-in/audio/whisperX/modeling_whisper.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/MindIE/MindIE-Torch/built-in/audio/whisperX/modeling_whisper.py b/MindIE/MindIE-Torch/built-in/audio/whisperX/modeling_whisper.py index 092ee80a39..bc044a2490 100644 --- a/MindIE/MindIE-Torch/built-in/audio/whisperX/modeling_whisper.py +++ b/MindIE/MindIE-Torch/built-in/audio/whisperX/modeling_whisper.py @@ -161,8 +161,8 @@ class MindieIFA(WhisperAttention): query=query_states, key=past_key_cache, value=past_value_cache, - actual_seq_len=actual_seq_len, - num_heads=self.num_heads, + actual_seq_lengths=actual_seq_len, + num_head=self.num_heads, scale=self.scaling, layout="BSND") @@ -559,15 +559,15 @@ class MindieWhisperForConditionalGeneration(WhisperForConditionalGeneration, Gen self.model = model def forward(self, *args): - return self.model.forward(args)[0] + return self.model.forward(*args)[0] mindie_whisper = MindieWhisperForConditionalGeneration.from_pretrained(args.model_path, enable_incre_flash_attention=True) decoder = Decoder(mindie_whisper) print("Start compiling decoder.") - - encoder_outputs = torch.randn([args.bs, self.encoder_seq_len, hidden_size]) + layer_nums = 32 + encoder_outputs = torch.randn([args.bs, self.encoder_seq_len, self.hidden_size]) decoder_input_ids = torch.randint(1, 4, (args.bs, 1)) actual_seq_len = torch.ones((args.bs, 1)) all_past_key_value = [torch.randn([args.bs, self.seq_len, self.head_num, self.head_size]), -- Gitee From a049f3f62461531874684ffb14a46383d9ef503a Mon Sep 17 00:00:00 2001 From: tiansongxue Date: Tue, 29 Oct 2024 16:36:13 +0800 Subject: [PATCH 33/44] fix22222 --- .../audio/whisperX/modeling_whisper.py | 20 +++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/MindIE/MindIE-Torch/built-in/audio/whisperX/modeling_whisper.py b/MindIE/MindIE-Torch/built-in/audio/whisperX/modeling_whisper.py index bc044a2490..e2ff01ffa3 100644 --- a/MindIE/MindIE-Torch/built-in/audio/whisperX/modeling_whisper.py +++ b/MindIE/MindIE-Torch/built-in/audio/whisperX/modeling_whisper.py @@ -566,30 +566,30 @@ class MindieWhisperForConditionalGeneration(WhisperForConditionalGeneration, Gen decoder = Decoder(mindie_whisper) print("Start compiling decoder.") - layer_nums = 32 encoder_outputs = torch.randn([args.bs, self.encoder_seq_len, self.hidden_size]) decoder_input_ids = torch.randint(1, 4, (args.bs, 1)) actual_seq_len = torch.ones((args.bs, 1)) all_past_key_value = [torch.randn([args.bs, self.seq_len, self.head_num, self.head_size]), torch.randn([args.bs, self.seq_len, self.head_num, self.head_size]), torch.randn([args.bs, self.encoder_seq_len, self.head_num, self.head_size]), - torch.randn([args.bs, self.encoder_seq_len, self.head_num, self.head_size])] * layer_nums + torch.randn([args.bs, self.encoder_seq_len, self.head_num, self.head_size]) + ] * self.layer_nums traced_args = 
[decoder_input_ids, encoder_outputs, actual_seq_len] traced_args.extend(all_past_key_value) traced_decoder = torch.jit.trace(decoder, traced_args) # BSND key_value_infos = [ - mindietorch.Input(shape=(args.bs, max_len, head_num, head_size), + mindietorch.Input(shape=(args.bs, self.max_len, self.head_num, self.head_size), dtype=dtype.FLOAT16), - mindietorch.Input(shape=(args.bs, max_len, head_num, head_size), + mindietorch.Input(shape=(args.bs, self.max_len, self.head_num, self.head_size), dtype=dtype.FLOAT16), - mindietorch.Input(shape=(args.bs, head_num, self.encoder_seq_len, head_size), + mindietorch.Input(shape=(args.bs, self.head_num, self.encoder_seq_len, self.head_size), dtype=dtype.FLOAT16), - mindietorch.Input(shape=(args.bs, head_num, self.encoder_seq_len, head_size), + mindietorch.Input(shape=(args.bs, self.head_num, self.encoder_seq_len, self.head_size), dtype=dtype.FLOAT16 - )] * layer_nums + )] * self.layer_nums input_info = [mindietorch.Input(shape=(args.bs, 1), dtype=dtype.INT64), - mindietorch.Input(shape=(args.bs, self.encoder_seq_len, hidden_size)), + mindietorch.Input(shape=(args.bs, self.encoder_seq_len, self.hidden_size)), mindietorch.Input(shape=(args.bs, 1), dtype=dtype.INT64)] input_info.extend(key_value_infos) @@ -649,8 +649,8 @@ class MindieWhisperForConditionalGeneration(WhisperForConditionalGeneration, Gen print(f"Directory {save_path} created.") else: print(f"Directory {save_path} already exists.") - self.compile_encoder(args) - self.compile_prefill_decoder(args) + # self.compile_encoder(args) + # self.compile_prefill_decoder(args) self.compile_decoder(args) self.compile_scatter_update(args) self.has_compile = True -- Gitee From 82cb643fffbbb71b58b1eda64715a30b1a44fa66 Mon Sep 17 00:00:00 2001 From: tiansongxue Date: Tue, 29 Oct 2024 17:08:11 +0800 Subject: [PATCH 34/44] fix33333 --- MindIE/MindIE-Torch/built-in/audio/whisperX/modeling_whisper.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/MindIE/MindIE-Torch/built-in/audio/whisperX/modeling_whisper.py b/MindIE/MindIE-Torch/built-in/audio/whisperX/modeling_whisper.py index e2ff01ffa3..c63e0a37fd 100644 --- a/MindIE/MindIE-Torch/built-in/audio/whisperX/modeling_whisper.py +++ b/MindIE/MindIE-Torch/built-in/audio/whisperX/modeling_whisper.py @@ -568,7 +568,7 @@ class MindieWhisperForConditionalGeneration(WhisperForConditionalGeneration, Gen encoder_outputs = torch.randn([args.bs, self.encoder_seq_len, self.hidden_size]) decoder_input_ids = torch.randint(1, 4, (args.bs, 1)) - actual_seq_len = torch.ones((args.bs, 1)) + actual_seq_len = torch.ones((args.bs)) all_past_key_value = [torch.randn([args.bs, self.seq_len, self.head_num, self.head_size]), torch.randn([args.bs, self.seq_len, self.head_num, self.head_size]), torch.randn([args.bs, self.encoder_seq_len, self.head_num, self.head_size]), -- Gitee From 0e668c79b59dd63e80c31b204e4fcca2edc552f3 Mon Sep 17 00:00:00 2001 From: tiansongxue Date: Tue, 29 Oct 2024 17:41:07 +0800 Subject: [PATCH 35/44] fix 44444444 --- .../MindIE-Torch/built-in/audio/whisperX/modeling_whisper.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/MindIE/MindIE-Torch/built-in/audio/whisperX/modeling_whisper.py b/MindIE/MindIE-Torch/built-in/audio/whisperX/modeling_whisper.py index c63e0a37fd..96fd3414c4 100644 --- a/MindIE/MindIE-Torch/built-in/audio/whisperX/modeling_whisper.py +++ b/MindIE/MindIE-Torch/built-in/audio/whisperX/modeling_whisper.py @@ -590,7 +590,7 @@ class MindieWhisperForConditionalGeneration(WhisperForConditionalGeneration, Gen )] 
* self.layer_nums
         input_info = [mindietorch.Input(shape=(args.bs, 1), dtype=dtype.INT64),
                       mindietorch.Input(shape=(args.bs, self.encoder_seq_len, self.hidden_size)),
-                      mindietorch.Input(shape=(args.bs, 1), dtype=dtype.INT64)]
+                      mindietorch.Input(shape=(args.bs), dtype=dtype.INT64)]
 
         input_info.extend(key_value_infos)
         float_size = 4
@@ -745,7 +745,7 @@ class MindieWhisperForConditionalGeneration(WhisperForConditionalGeneration, Gen
                 args.append(torch.tensor([kv_actual_step] * input_ids.shape[0]).to("npu"))
                 args.extend(self.past_key_value)
                 outputs = self.mindie_decoder(*args)
-            if isinstacnce(outputs, list):
+            if isinstance(outputs, list):
                 next_token_logits = outputs[0].to("cpu")[:, -1, :]
             else:
                 next_token_logits = outputs.to("cpu")[:, -1, :]
-- Gitee

From 5b1970d5b4e2c7a1b3c2f8a109ce05bc4875d110 Mon Sep 17 00:00:00 2001
From: tiansongxue
Date: Tue, 29 Oct 2024 19:27:33 +0800
Subject: [PATCH 36/44] fix555555555555555

---
 .../built-in/audio/whisperX/modeling_whisper.py | 10 +++++-----
 1 file changed, 5 insertions(+), 5 deletions(-)

diff --git a/MindIE/MindIE-Torch/built-in/audio/whisperX/modeling_whisper.py b/MindIE/MindIE-Torch/built-in/audio/whisperX/modeling_whisper.py
index 96fd3414c4..cb3f1ddf81 100644
--- a/MindIE/MindIE-Torch/built-in/audio/whisperX/modeling_whisper.py
+++ b/MindIE/MindIE-Torch/built-in/audio/whisperX/modeling_whisper.py
@@ -583,14 +583,14 @@ class MindieWhisperForConditionalGeneration(WhisperForConditionalGeneration, Gen
                 dtype=dtype.FLOAT16),
             mindietorch.Input(shape=(args.bs, self.max_len, self.head_num, self.head_size),
                 dtype=dtype.FLOAT16),
-            mindietorch.Input(shape=(args.bs, self.head_num, self.encoder_seq_len, self.head_size),
+            mindietorch.Input(shape=(args.bs, self.encoder_seq_len, self.head_num, self.head_size),
                 dtype=dtype.FLOAT16),
-            mindietorch.Input(shape=(args.bs, self.head_num, self.encoder_seq_len, self.head_size),
+            mindietorch.Input(shape=(args.bs, self.encoder_seq_len, self.head_num, self.head_size),
                 dtype=dtype.FLOAT16
             )] * self.layer_nums
         input_info = [mindietorch.Input(shape=(args.bs, 1), dtype=dtype.INT64),
                       mindietorch.Input(shape=(args.bs, self.encoder_seq_len, self.hidden_size)),
-                      mindietorch.Input(shape=(args.bs), dtype=dtype.INT64)]
+                      mindietorch.Input(shape=(args.bs, ), dtype=dtype.INT64)]
 
         input_info.extend(key_value_infos)
         float_size = 4
@@ -615,11 +615,11 @@ class MindieWhisperForConditionalGeneration(WhisperForConditionalGeneration, Gen
 
         class MindieScatter(torch.nn.Module):
             def forward(self, past_key_value, indices, update_states):
-                out = torch.ops.aie.scatter_update(self, past_key_value, indices, update_states, 1)
+                out = torch.ops.aie.scatter_update(past_key_value, indices, update_states, 1)
                 return out
         bs = args.bs
         self_past_key_value = torch.randn([bs, self.max_len, self.head_num, self.head_size])
-        encoder_past_key_value = torch.rand([bs, self.max_len, self.head_num, self.head_size])
+        encoder_past_key_value = torch.randn([bs, self.max_len, self.head_num, self.head_size])
         indices = torch.tensor([0]*bs)
         update_states = torch.randn([bs, 1, 20, 64])
         traced = torch.jit.trace(MindieScatter(), (self_past_key_value, indices, update_states))
-- Gitee
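
Patch 36 above settles the `scatter_update` call sites. For intuition, a plain-PyTorch reading of what `torch.ops.aie.scatter_update(data, indices, updates, 1)` is used for in this series; this is an illustrative equivalent inferred from the call sites, not the MindIE kernel:

```python
import torch

def scatter_update_ref(data, indices, updates, axis=1):
    # Write `updates` into `data` at per-batch offset `indices[b]` along `axis`;
    # every call site in this series uses axis=1, the sequence axis.
    assert axis == 1
    out = data.clone()
    seq = updates.shape[1]
    for b in range(data.shape[0]):
        start = int(indices[b])
        out[b, start:start + seq] = updates[b]
    return out

cache = torch.zeros(2, 448, 20, 64)   # (bs, max_len, head_num, head_size)
new_kv = torch.randn(2, 1, 20, 64)    # one decoding step of K or V
positions = torch.tensor([5, 7])      # current write position per sequence
cache = scatter_update_ref(cache, positions, new_kv)
```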

From 0471320cb80346776d308174e6123af619266fa0 Mon Sep 17 00:00:00 2001
From: tiansongxue
Date: Fri, 8 Nov 2024 09:20:13 +0800
Subject: [PATCH 37/44] whisperX

---
 .../audio/whisperX/compile_whisper.py         | 153 ++++++++++++++-
 .../audio/whisperX/modeling_whisper.py        | 176 +++++++++++++-----
 .../built-in/audio/whisperX/readme.md         |  80 ++++++++
 3 files changed, 360 insertions(+), 49 deletions(-)
 create mode 100644 MindIE/MindIE-Torch/built-in/audio/whisperX/readme.md

diff --git a/MindIE/MindIE-Torch/built-in/audio/whisperX/compile_whisper.py b/MindIE/MindIE-Torch/built-in/audio/whisperX/compile_whisper.py
index 217b42d96e..77b5efaf57 100644
--- a/MindIE/MindIE-Torch/built-in/audio/whisperX/compile_whisper.py
+++ b/MindIE/MindIE-Torch/built-in/audio/whisperX/compile_whisper.py
@@ -4,6 +4,143 @@
 import mindietorch
 from modeling_whisper import MindieWhisperForConditionalGeneration
+import os
+import math
+import torch
+from mindietorch._enums import dtype
 
 
+def compile_encoder(model: MindieWhisperForConditionalGeneration,
+                    args: argparse.Namespace):
+    encoder = model.get_encoder()
+
+    class Encoder(torch.nn.Module):
+
+        def __init__(self, model):
+            super().__init__()
+            self.model = model
+
+        def forward(self, input_features):
+            return self.model(input_features=input_features, return_dict=False)
+
+    mel_feature_size = 128
+    max_frames = 3000
+    input_features = torch.randn([args.bs, mel_feature_size, max_frames])
+    encoder_traced = torch.jit.trace(Encoder(encoder), (input_features))
+    input_info = [mindietorch.Input(shape=(args.bs, mel_feature_size, max_frames))]
+    print("Start AOE optimization_level 1.")
+    compiled = mindietorch.compile(encoder_traced,
+                                   inputs=input_info,
+                                   precision_policy=mindietorch.PrecisionPolicy.FP16,
+                                   soc_version=args.soc_version,
+                                   optimization_level=1
+                                   )
+    print("AOE optimize finished and start compile encoder.")
+    compiled = mindietorch.compile(encoder_traced,
+                                   inputs=input_info,
+                                   precision_policy=mindietorch.PrecisionPolicy.FP16,
+                                   soc_version=args.soc_version,
+                                   )
+    save_file = os.path.join(args.save_path, f"{model.file_prefix_names[0]}{args.bs}.ts")
+    torch.jit.save(compiled, save_file)
+    print(f"Compile encoder success, saved in {save_file}")
+
+
+def compile_prefill_decoder(model: MindieWhisperForConditionalGeneration,
+                            args: argparse.Namespace):
+    print("Start compiling prefill_decoder.")
+    encoder_seq_len = 1500
+    hidden_size = 1280
+    encoder_outputs = torch.randn([args.bs, encoder_seq_len, hidden_size])
+    decoder_input_ids = torch.randint(1, 4, (args.bs, 1))
+    prefill_decoder_traced = torch.jit.trace(model.eval(), (decoder_input_ids, encoder_outputs))
+    input_info = [mindietorch.Input(shape=(args.bs, 1), dtype=dtype.INT64),
+                  mindietorch.Input(shape=(args.bs, encoder_seq_len, hidden_size))]
+    prefill_decoder_compiled = mindietorch.compile(prefill_decoder_traced,
+                                                   inputs=input_info,
+                                                   precision_policy=mindietorch.PrecisionPolicy.FP16,
+                                                   soc_version=args.soc_version)
+
+    save_file = os.path.join(args.save_path, f"{model.file_prefix_names[1]}{args.bs}.ts")
+    torch.jit.save(prefill_decoder_compiled, save_file)
+    print(f"Compile prefill_decoder success, saved in {save_file}.")
+
+
+def compile_incre_decoder(args: argparse.Namespace):
+
+    class Decoder(torch.nn.Module):
+
+        def __init__(self, model):
+            super().__init__()
+            self.model = model
+
+        def forward(self, *args):
+            return self.model.forward(*args)[0]
+
+    mindie_whisper = MindieWhisperForConditionalGeneration.from_pretrained(args.model_path,
+                                                                           is_incre_decode=True,
+                                                                           soc_version=args.soc_version)
+    decoder = Decoder(mindie_whisper)
+    print("Start compiling decoder.")
+
+    encoder_outputs = torch.randn([args.bs, mindie_whisper.encoder_seq_len, mindie_whisper.hidden_size])
+    decoder_input_ids = torch.randint(1, 4, (args.bs, 1))
+    actual_seq_len = torch.ones((args.bs))
+    all_past_key_value = [torch.randn([args.bs, mindie_whisper.seq_len,
+                                       mindie_whisper.head_num, mindie_whisper.head_size]),
+                          torch.randn([args.bs, mindie_whisper.seq_len,
+                                       mindie_whisper.head_num, mindie_whisper.head_size]),
+                          torch.randn([args.bs, mindie_whisper.encoder_seq_len,
+                                       mindie_whisper.head_num, mindie_whisper.head_size]),
+                          torch.randn([args.bs, mindie_whisper.encoder_seq_len,
+                                       mindie_whisper.head_num, mindie_whisper.head_size])
+                          ] * mindie_whisper.layer_nums
+    traced_args = [decoder_input_ids, encoder_outputs, actual_seq_len]
+    traced_args.extend(all_past_key_value)
+    traced_decoder = torch.jit.trace(decoder, traced_args)
+    # BSND
+    key_value_infos = [
+        mindietorch.Input(shape=(args.bs, mindie_whisper.max_len,
+                                 mindie_whisper.head_num, mindie_whisper.head_size),
+                          dtype=dtype.FLOAT16),
+        mindietorch.Input(shape=(args.bs, mindie_whisper.max_len,
+                                 mindie_whisper.head_num, mindie_whisper.head_size),
+                          dtype=dtype.FLOAT16),
+        mindietorch.Input(shape=(args.bs, mindie_whisper.encoder_seq_len,
+                                 mindie_whisper.head_num, mindie_whisper.head_size),
+                          dtype=dtype.FLOAT16),
+        mindietorch.Input(shape=(args.bs, mindie_whisper.encoder_seq_len,
+                                 mindie_whisper.head_num, mindie_whisper.head_size),
+                          dtype=dtype.FLOAT16
+                          )] * mindie_whisper.layer_nums
+    input_info = [mindietorch.Input(shape=(args.bs, 1), dtype=dtype.INT64),  # input ids
+                  mindietorch.Input(shape=(args.bs, mindie_whisper.encoder_seq_len, mindie_whisper.hidden_size)),
+                  mindietorch.Input(shape=(args.bs,), dtype=dtype.INT64)]  # actual seq len
+
+    input_info.extend(key_value_infos)
+    float_size = 4
+    voc_size = 51866
+    buffer_size = math.ceil((args.bs * 1 * voc_size * float_size) / 1024 / 1024)
+    print(f"Set {buffer_size}/MB for output.")
+    compiled_decoder = mindietorch.compile(traced_decoder,
+                                           inputs=input_info,
+                                           precision_policy=mindietorch.PrecisionPolicy.FP16,
+                                           soc_version=args.soc_version,
+                                           default_buffer_size_vec=[buffer_size])
+    save_file = os.path.join(args.save_path, f"{mindie_whisper.file_prefix_names[2]}{args.bs}.ts")
+    torch.jit.save(compiled_decoder, save_file)
+
+    print(f"Compile whisper_decoder success, saved in {save_file}.")
+
+
+def compile_scatter_update(model: MindieWhisperForConditionalGeneration,
+                           args: argparse.Namespace):
+
+    class MindieScatter(torch.nn.Module):
+        def forward(self, past_key_value, indices, update_states):
+            out = torch.ops.aie.scatter_update(past_key_value, indices, update_states, 1)
+            return out
+
+    bs = args.bs
+    self_past_key_value = torch.randn([bs, model.max_len, model.head_num, model.head_size])
+    encoder_past_key_value = torch.randn([bs, model.max_len, model.head_num, model.head_size])
+    indices = torch.tensor([0] * bs)
+    update_states = torch.randn([bs, 1, 20, 64])
+    traced = torch.jit.trace(MindieScatter(), (self_past_key_value, indices, update_states))
+
+    self_attn_info = mindietorch.Input(shape=self_past_key_value.shape, dtype=mindietorch.dtype.FLOAT16)
+    encoder_attn_info = mindietorch.Input(shape=encoder_past_key_value.shape, dtype=mindietorch.dtype.FLOAT16)
+    indices_info = mindietorch.Input(shape=indices.shape, dtype=mindietorch.dtype.INT64)
+    update_states_info = mindietorch.Input(shape=encoder_past_key_value.shape, dtype=mindietorch.dtype.FLOAT16)
+
+    compile_self = mindietorch.compile(traced, inputs=[self_attn_info, indices_info, update_states_info],
+                                       soc_version=args.soc_version)
+    torch.jit.save(compile_self, f"{args.save_path}/{model.file_prefix_names[3]}{args.bs}.ts")
+
+    compile_self = mindietorch.compile(traced, inputs=[encoder_attn_info, indices_info, update_states_info],
+                                       soc_version=args.soc_version)
+    torch.jit.save(compile_self, f"{args.save_path}/{model.file_prefix_names[4]}{args.bs}.ts")
+    print("compile scatter success.")
+
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
     parser.add_argument('-model_path', type=str, required=True, help="please provide model path.")
@@ -14,6 +151,18 @@
     parser.add_argument('-device_id', type=int, default=0)
 
     args = parser.parse_args()
-    mindie_whisper = MindieWhisperForConditionalGeneration.from_pretrained(args.model_path)
+    mindie_whisper = MindieWhisperForConditionalGeneration.from_pretrained(args.model_path,
+                                                                           soc_version=args.soc_version)
     device = f"npu:{args.device_id}"
-    mindie_whisper.compile(args)
\ No newline at end of file
+    print("Start compiling Mindie-Whisper, it will take some time, please wait.")
+    if not args.save_path:
+        raise ValueError("Please provide the directory where the compiled model will be saved.")
+    if not os.path.exists(args.save_path):
+        os.makedirs(args.save_path)
+        print(f"Directory {args.save_path} created.")
+    else:
+        print(f"Directory {args.save_path} already exists.")
+
+    compile_scatter_update(mindie_whisper, args)
+    compile_encoder(mindie_whisper, args)
+    compile_prefill_decoder(mindie_whisper, args)
+    compile_incre_decoder(args)
\ No newline at end of file
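
The traced decoder above fixes the positional interface that `MindieWhisperForConditionalGeneration.forward` checks: two fixed tensors plus `actual_seq_len` plus four cache tensors per layer. A quick arity sanity check under the constants used here:

```python
layer_nums = 32
fixed_inputs = ["decoder_input_ids", "encoder_outputs", "actual_seq_len"]
kv_per_layer = ["self_k", "self_v", "cross_k", "cross_v"]

total_args = len(fixed_inputs) + layer_nums * len(kv_per_layer)
assert total_args == 131  # matches the len(args) check in forward()
```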

diff --git a/MindIE/MindIE-Torch/built-in/audio/whisperX/modeling_whisper.py b/MindIE/MindIE-Torch/built-in/audio/whisperX/modeling_whisper.py
index cb3f1ddf81..6e33a9d817 100644
--- a/MindIE/MindIE-Torch/built-in/audio/whisperX/modeling_whisper.py
+++ b/MindIE/MindIE-Torch/built-in/audio/whisperX/modeling_whisper.py
@@ -138,24 +138,21 @@ class MindieIFA(WhisperAttention):
         **kwargs
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
         # self attn
-        assert past_key_value is not None
+        assert past_key_value is not None, \
+            "Current operation is incre_flash_attention, past_key_value is required."
         bsz, tgt_len, _ = hidden_states.size()
-        assert tgt_len == 1
+        assert tgt_len == 1, \
+            "Current operation is incre_flash_attention, query's seq length should be equal to 1."
         query_states = self.q_proj(hidden_states)
-
         key_states = self._shape(self.k_proj(hidden_states), -1, bsz).to(torch.float16)
         values_states = self._shape(self.v_proj(hidden_states), -1, bsz).to(torch.float16)
-
         past_key_cache, past_value_cache = past_key_value[0], past_key_value[1]
-
         indices = actual_seq_len - 1
         past_key_cache = torch.ops.aie.scatter_update(past_key_cache, indices, key_states, axis=1)
         past_value_cache = torch.ops.aie.scatter_update(past_value_cache, indices, values_states, axis=1)
-
         past_key_value = (past_key_cache, past_value_cache)
         # B S N D
         query_states = query_states.view(bsz, tgt_len, self.num_heads, self.head_dim).contiguous()
-
         # B S N D
         attn_output = torch.ops.aie.incre_flash_attention(
             query=query_states,
@@ -165,12 +162,90 @@ class MindieIFA(WhisperAttention):
             num_head=self.num_heads,
             scale=self.scaling,
             layout="BSND")
-
         attn_output = attn_output.view(bsz, tgt_len, self.num_heads, self.head_dim)
         attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)
         attn_output = self.out_proj(attn_output)
-        return attn_output, _, past_key_value
+        return attn_output, None, past_key_value
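
`incre_flash_attention` attends a single-token query over the whole cache, with `actual_seq_lengths` masking slots that have not been written yet. A rough dense-PyTorch equivalent of my reading of the call above (BSND layout; a sketch for intuition, not the fused kernel):

```python
import torch

def incre_flash_attention_ref(query, key, value, actual_seq_lengths, scale):
    # query: (B, 1, N, D); key/value: (B, S, N, D); all in BSND layout.
    bsz, _, num_head, head_dim = query.shape
    seq = key.shape[1]
    q = query.permute(0, 2, 1, 3)        # (B, N, 1, D)
    k = key.permute(0, 2, 3, 1)          # (B, N, D, S)
    scores = torch.matmul(q, k) * scale  # (B, N, 1, S)
    # Mask cache slots beyond each sequence's valid length.
    pos = torch.arange(seq).view(1, 1, 1, seq)
    valid = pos < actual_seq_lengths.view(bsz, 1, 1, 1)
    scores = scores.masked_fill(~valid, float("-inf"))
    probs = torch.softmax(scores, dim=-1)
    out = torch.matmul(probs, value.permute(0, 2, 1, 3))  # (B, N, 1, D)
    return out.permute(0, 2, 1, 3)       # back to (B, 1, N, D)

q = torch.randn(2, 1, 20, 64)
kv = torch.randn(2, 448, 20, 64)
out = incre_flash_attention_ref(q, kv, kv, torch.tensor([3, 7]), scale=64 ** -0.5)
```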
+
+
+class MindieFA(WhisperAttention):
+
+    def __init__(
+        self,
+        embed_dim: int,
+        num_heads: int,
+        dropout: float = 0.0,
+        is_decoder: bool = False,
+        bias: bool = True,
+        is_causal: bool = False,
+        config: bool = None
+    ):
+
+        super().__init__(
+            embed_dim,
+            num_heads,
+            dropout,
+            bias,
+            is_causal,
+            config)
+
+    # BSND
+    def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
+        return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).contiguous()
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        key_value_states: Optional[torch.Tensor] = None,
+        past_key_value: Optional[Tuple[torch.Tensor]] = None,
+        **kwargs
+
+    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+        # cross_attn
+        is_cross_attn = key_value_states is not None
+        bsz, tgt_len, _ = hidden_states.size()
+        query_states = self.q_proj(hidden_states)
+        if (
+            is_cross_attn
+            and past_key_value is not None
+            # past_key_value layout is BSND
+            and past_key_value[0].shape[1] == key_value_states.shape[1]
+        ):
+            # reuse k, v
+            key_states = past_key_value[0]
+            value_states = past_key_value[1]
+        elif is_cross_attn:
+            key_states = self._shape(self.k_proj(key_value_states), -1, bsz)
+            value_states = self._shape(self.v_proj(key_value_states), -1, bsz)
+        # self attn
+        elif past_key_value is not None:
+            # reuse k v
+            key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
+            value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
+            key_states = torch.cat([past_key_value[0], key_states], dim=1)
+            value_states = torch.cat([past_key_value[1], value_states], dim=1)
+        else:
+            key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
+            value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
+
+        if self.is_decoder:
+            past_key_value = (key_states.to(torch.float16), value_states.to(torch.float16))
+
+        # MindFA only supports BNSD layout
+        query_states = query_states.view(bsz, tgt_len, self.num_heads, self.head_dim).contiguous().transpose(1, 2)
+        attn_output = torch.ops.aie.flash_attention(
+            query=query_states,
+            key=key_states.transpose(1, 2),
+            value=value_states.transpose(1, 2),
+            num_head=self.num_heads,
+            scale=self.scaling,
+            layout="BNSD",
+            type="FA_HIGH_PERF"
+        )
+        attn_output = attn_output.transpose(1, 2)
+        attn_output = attn_output.view(bsz, tgt_len, self.num_heads, self.head_dim)
+        attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)
+        attn_output = self.out_proj(attn_output)
+        return attn_output, None, past_key_value
 
 class MindieAttention(WhisperAttention):
 
     def __init__(
@@ -233,29 +308,51 @@ class MindieAttention(WhisperAttention):
 
 class MindieWhisperDecoderLayer(WhisperDecoderLayer):
 
-    def __init__(self, config):
+    def __init__(self, config, is_incre_decode, soc_version):
         super().__init__(config)
         self.embed_dim = config.d_model
-        if config.is_use_ifa:
-            self.self_attn = MindieIFA(
+        if is_incre_decode:
+            self.self_attn = MindieIFA(
+                embed_dim=self.embed_dim,
+                num_heads=config.decoder_attention_heads,
+                config=config,
+                is_decoder=True
+            )
+            self.encoder_attn = MindieIFA(
+                embed_dim=self.embed_dim,
+                num_heads=config.decoder_attention_heads,
+                is_decoder=True,
+                config=config
+            )
+        elif soc_version == "Ascend910B4":
+            self.self_attn = MindiePFA(
+                embed_dim=self.embed_dim,
+                num_heads=config.decoder_attention_heads,
+                config=config,
+                is_decoder=True
+            )
+            self.encoder_attn = MindiePFA(
                 embed_dim=self.embed_dim,
                 num_heads=config.decoder_attention_heads,
-                config=config,
-                is_decoder=True
+                is_decoder=True,
+                config=config
             )
-        else:
-            self.self_attn = MindiePFA(
+        elif soc_version == "Ascend310P3":
+            self.self_attn = MindieFA(
+                embed_dim=self.embed_dim,
+                num_heads=config.decoder_attention_heads,
+                config=config,
+                is_decoder=True
+            )
+            self.encoder_attn = MindieFA(
                 embed_dim=self.embed_dim,
                 num_heads=config.decoder_attention_heads,
-                config=config,
-                is_decoder=True
+                is_decoder=True,
+                config=config
             )
-        self.encoder_attn = MindiePFA(
-            embed_dim=self.embed_dim,
-            num_heads=config.decoder_attention_heads,
-            is_decoder=True,
-            config=config
-        )
+        else:
+            raise ValueError(f"Unsupported parameters, e.g. "
+                             f"is_incre_decode: {is_incre_decode}, soc_version: {soc_version}")
" + f"is_incre_decode: {is_incre_decode} soc_version {soc_version}") def forward( self, @@ -303,7 +400,6 @@ class MindieWhisperDecoderLayer(WhisperDecoderLayer): hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) hidden_states = residual + hidden_states - # add cross-attn to positions 3,4 of present_key_value tuple present_key_value = present_key_value + cross_attn_present_key_value # Fully Connected @@ -314,7 +410,6 @@ class MindieWhisperDecoderLayer(WhisperDecoderLayer): hidden_states = self.fc2(hidden_states) hidden_states = hidden_states.reshape(-1, 1, self.embed_dim) hidden_states = residual + hidden_states - outputs = (hidden_states,) if use_cache: @@ -325,9 +420,10 @@ class MindieWhisperDecoderLayer(WhisperDecoderLayer): class MindieWhisperDecoder(WhisperDecoder): - def __init__(self, config): + def __init__(self, config, is_incre_decode, soc_version): super().__init__(config) - self.layers = nn.ModuleList([MindieWhisperDecoderLayer(config) for _ in range(config.decoder_layers)]) + self.layers = nn.ModuleList([MindieWhisperDecoderLayer(config, is_incre_decode, soc_version) + for _ in range(config.decoder_layers)]) self.config = config def forward( @@ -346,10 +442,6 @@ class MindieWhisperDecoder(WhisperDecoder): input_ids = input_ids.view(-1, input_shape[-1]) # B S N D past_key_values_length = past_key_values[0].shape[1] if past_key_values is not None else 0 - - attention_mask = _prepare_4d_causal_attention_mask( - attention_mask, input_shape, inputs_embeds, past_key_values_length - ) inputs_embeds = self.embed_tokens(input_ids) # embed positions positions = self.embed_positions(input_ids, past_key_values_length=past_key_values_length) @@ -401,9 +493,9 @@ class MindieWhisperEncoder(WhisperEncoder): class MindieWhisperModel(WhisperModel): - def __init__(self, config): + def __init__(self, config, is_incre_decode, soc_version): super().__init__(config) - self.decoder = MindieWhisperDecoder(config) + self.decoder = MindieWhisperDecoder(config, is_incre_decode, soc_version) self.encoder = MindieWhisperEncoder(config) def forward( @@ -443,10 +535,10 @@ class MindieWhisperModel(WhisperModel): class MindieWhisperForConditionalGeneration(WhisperForConditionalGeneration, GenerationMixin): - def __init__(self, config, enable_incre_flash_attention=False): + def __init__(self, config, is_incre_decode=False, soc_version="Ascend310P3"): super().__init__(config) config.is_use_ifa = enable_incre_flash_attention - self.model = MindieWhisperModel(config) + self.model = MindieWhisperModel(config, is_incre_decode, soc_version) self.has_load = False self.has_compile = False self.mindie_encoder = None @@ -524,9 +616,6 @@ class MindieWhisperForConditionalGeneration(WhisperForConditionalGeneration, Gen ) save_file = os.path.join(args.save_path, f"{self.file_prefix_names[0]}{args.bs}.ts") torch.jit.save(compiled, save_file) - - traced_file = os.path.join(args.save_path, f"{self.file_prefix_names[0]}{args.bs}.pt") - torch.jit.save(encoder_traced, traced_file) print(f"Compile encoder success, saved in {save_file}") def compile_prefill_decoder(self, args): @@ -543,9 +632,6 @@ class MindieWhisperForConditionalGeneration(WhisperForConditionalGeneration, Gen precision_policy=mindietorch.PrecisionPolicy.FP16, soc_version=args.soc_version) - traced_file = os.path.join(args.save_path, f"{self.file_prefix_names[1]}{args.bs}.pt") - torch.jit.save(prefill_decoder_traced, traced_file) - save_file = os.path.join(args.save_path, f"{self.file_prefix_names[1]}{args.bs}.ts") 
        torch.jit.save(prefill_decoder_compiled, save_file)
        print(f"Compile prefill_decoder success, saved in {save_file}.")
 
@@ -602,10 +688,6 @@ class MindieWhisperForConditionalGeneration(WhisperForConditionalGeneration, Gen
             precision_policy=mindietorch.PrecisionPolicy.FP16,
             soc_version=args.soc_version,
             default_buffer_size_vec=[buffer_size])
-
-        traced_file = os.path.join(args.save_path, f"{self.file_prefix_names[2]}{args.bs}.pt")
-        torch.jit.save(traced_decoder, traced_file)
-
         save_file = os.path.join(args.save_path, f"{self.file_prefix_names[2]}{args.bs}.ts")
         torch.jit.save(compiled_decoder, save_file)
 
@@ -649,8 +731,8 @@ class MindieWhisperForConditionalGeneration(WhisperForConditionalGeneration, Gen
                 print(f"Directory {save_path} created.")
             else:
                 print(f"Directory {save_path} already exists.")
-        # self.compile_encoder(args)
-        # self.compile_prefill_decoder(args)
+        self.compile_encoder(args)
+        self.compile_prefill_decoder(args)
         self.compile_decoder(args)
         self.compile_scatter_update(args)
         self.has_compile = True
diff --git a/MindIE/MindIE-Torch/built-in/audio/whisperX/readme.md b/MindIE/MindIE-Torch/built-in/audio/whisperX/readme.md
new file mode 100644
index 0000000000..c9f020d143
--- /dev/null
+++ b/MindIE/MindIE-Torch/built-in/audio/whisperX/readme.md
@@ -0,0 +1,80 @@
+# Whisper-large-v3 Model - Inference Guide
+
+- [Overview](#overview)
+- [Inference Environment](#inference-environment)
+- [Quick Start](#quick-start)
+  - [Getting the Source](#getting-the-source)
+  - [Model Inference](#model-inference)
+
+# Overview
+
+This project deploys the whisper-large-v3 model with mindietorch.
+
+
+# Inference Environment
+
+- The model requires the following components.
+
+  **Table 1** Version compatibility
+
+  | Component   | Version   | Setup Guide |
+  |-------------|-----------|-------------|
+  | Python      | 3.10.13   | -           |
+  | torch       | 2.1.0+cpu | -           |
+  | torch_audio | 2.1.0+cpu | -           |
+  | CANN        | 8.0.RC3   | -           |
+  | MindIE      | 1.0.RC3   | -           |
+
+# Quick Start
+## Getting the Source
+
+1. Install the mindie package
+
+    ```bash
+    # install mindie
+    chmod +x ./Ascend-mindie_xxx.run
+    ./Ascend-mindie_xxx.run --install
+    source /usr/local/Ascend/mindie/set_env.sh
+    ```
+
+
+2. Model weights download link:
+    ```bash
+    https://huggingface.co/openai/whisper-large-v3/tree/main
+    ```
+    Place the weight files in the model_path folder under the current directory; create this folder first.
+
+
+3. Install dependencies
+    ```
+    pip3 install torch==2.1.0 --index-url https://download.pytorch.org/whl/cpu
+    pip3 install nltk
+    pip3 install librosa
+    pip3 install transformers==4.36.0
+    pip3 install numpy==1.26.0
+    ```
+
+## Model Inference
+1. Raise the mindie memory-pool limit to 32 by setting the environment variable below. If the pool is too small, repeated memory allocation and release will degrade performance.
+    ```
+    export TORCH_AIE_NPU_CACHE_MAX_SIZE=32
+    ```
+
+2. Model compilation and inference
+    ```
+    python3 compile_whisper.py \
+    -model_path ./model_path \
+    -bs 16 \
+    -save_path ./compiled_models \
+    -soc_version Ascend310P3
+    ```
+
+    Parameter description:
+    - -model_path: path to the pretrained model; required.
+    - -bs: batch_size, default 8; optional.
+    - -save_path: directory where the compiled models are saved; required.
+    - -device_id: index of the device the model runs on, default 0; optional.
+    - -soc_version: chip type, default Ascend310P3; only Ascend310P3 and Ascend910B4 are supported.
+
+    Constraints:
+    1. Dynamic batching is not yet supported; after batch_size changes, the graphs must be recompiled.
-- Gitee
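For a quick sanity check of a compiled encoder artifact before wiring up full inference, a minimal sketch follows. It is illustrative only, not one of the shipped scripts: it assumes the artifact naming used above (`mindie_whisper_encoder_bs{bs}.ts`), whisper-large-v3's 128-mel x 3000-frame input, and the usual MindIE-Torch pattern of `torch.jit.load` plus NPU-resident tensors; `mindietorch.set_device` and the `npu:` device string are assumptions based on common MindIE-Torch usage, not on these patches.

```python
# Minimal smoke test (editorial sketch, with the assumptions stated above).
import torch
import mindietorch

bs, device_id = 16, 0
mindietorch.set_device(device_id)
encoder = torch.jit.load(f"./compiled_models/mindie_whisper_encoder_bs{bs}.ts")

# whisper-large-v3 log-mel features: 128 mel bins x 3000 frames (30 s of audio)
features = torch.randn(bs, 128, 3000).to(f"npu:{device_id}")
out = encoder(features)
hidden = out[0] if isinstance(out, (tuple, list)) else out
print(hidden.shape)  # expected: (bs, 1500, 1280) encoder hidden states
```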
From 198452a2fdc2e07a387ad3e0bd7fc4bd61fc859b Mon Sep 17 00:00:00 2001
From: tiansongxue
Date: Fri, 8 Nov 2024 09:42:40 +0800
Subject: [PATCH 38/44] whisperX

---
 .../audio/whisperX/modeling_whisper.py        | 67 +++++++------------
 1 file changed, 25 insertions(+), 42 deletions(-)

diff --git a/MindIE/MindIE-Torch/built-in/audio/whisperX/modeling_whisper.py b/MindIE/MindIE-Torch/built-in/audio/whisperX/modeling_whisper.py
index 6e33a9d817..87fada7d47 100644
--- a/MindIE/MindIE-Torch/built-in/audio/whisperX/modeling_whisper.py
+++ b/MindIE/MindIE-Torch/built-in/audio/whisperX/modeling_whisper.py
@@ -23,7 +23,6 @@
 import mindietorch
 from mindietorch._enums import dtype
 
-
 class MindiePFA(WhisperAttention):
 
     def __init__(
@@ -34,16 +33,14 @@ class MindiePFA(WhisperAttention):
         is_decoder: bool = False,
         bias: bool = True,
         is_causal: bool = False,
-        config: bool = None
-    ):
-
+        config: bool = None ):
         super().__init__(
             embed_dim,
             num_heads,
             dropout,
             bias,
             is_causal,
-            config)
+            config )
 
     # BSND
     def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
@@ -54,9 +51,7 @@ class MindiePFA(WhisperAttention):
         hidden_states: torch.Tensor,
         key_value_states: Optional[torch.Tensor] = None,
         past_key_value: Optional[Tuple[torch.Tensor]] = None,
-        **kwargs
-
-    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+        **kwargs ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
         # cross_attn
         is_cross_attn = key_value_states is not None
         bsz, tgt_len, _ = hidden_states.size()
@@ -98,8 +93,7 @@ class MindiePFA(WhisperAttention):
             num_head=self.num_heads,
             scale=self.scaling,
             layout="BSND",
-            type="PFA"
-        )
+            type="PFA" )
         attn_output = attn_output.view(bsz, tgt_len, self.num_heads, self.head_dim)
         attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)
         attn_output = self.out_proj(attn_output)
@@ -114,8 +108,7 @@ class MindieIFA(WhisperAttention):
         is_decoder: bool = False,
         bias: bool = True,
         is_causal: bool = True,
-        config=None
-    ):
+        config=None ):
         super().__init__(
             embed_dim,
             num_heads,
@@ -123,8 +116,7 @@ class MindieIFA(WhisperAttention):
             is_decoder,
             bias,
             is_causal,
-            config
-        )
+            config )
 
     # BSND
     def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
@@ -135,8 +127,7 @@ class MindieIFA(WhisperAttention):
         hidden_states: torch.Tensor,
         past_key_value: torch.Tensor,
         actual_seq_len: torch.Tensor,
-        **kwargs
-    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+        **kwargs ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
         # self attn
         assert past_key_value is not None, \
             "Current operation is incre_flash_attention, past_key_value is required."
@@ -177,9 +168,7 @@ class MindieFA(WhisperAttention): is_decoder: bool = False, bias: bool = True, is_causal: bool = False, - config: bool = None - ): - + config: bool = None ): super().__init__( embed_dim, num_heads, @@ -197,9 +186,7 @@ class MindieFA(WhisperAttention): hidden_states: torch.Tensor, key_value_states: Optional[torch.Tensor] = None, past_key_value: Optional[Tuple[torch.Tensor]] = None, - **kwargs - - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + **kwargs ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: # cross_attn is_cross_attn = key_value_states is not None bsz, tgt_len, _ = hidden_states.size() @@ -239,8 +226,7 @@ class MindieFA(WhisperAttention): num_head=self.num_heads, scale=self.scaling, layout="BNSD", - type="FA_HIGH_PERF" - ) + type="FA_HIGH_PERF") attn_output = attn_output.transpose(1, 2) attn_output = attn_output.view(bsz, tgt_len, self.num_heads, self.head_dim) attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim) @@ -256,8 +242,7 @@ class MindieAttention(WhisperAttention): is_decoder: bool = False, bias: bool = True, is_causal: bool = False, - config=None - ): + config=None ): super().__init__( embed_dim, num_heads, @@ -265,8 +250,7 @@ class MindieAttention(WhisperAttention): is_decoder, bias, is_causal, - config - ) + config ) # BSND def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): @@ -277,8 +261,7 @@ class MindieAttention(WhisperAttention): hidden_states: torch.Tensor, key_value_states: Optional[torch.Tensor] = None, past_key_value: Optional[Tuple[torch.Tensor]] = None, - **kwargs - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + **kwargs ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: bsz, tgt_len, _ = hidden_states.size() query_states = self.q_proj(hidden_states) @@ -297,8 +280,7 @@ class MindieAttention(WhisperAttention): num_head=self.num_heads, scale=self.scaling, layout="BSND", - type="PFA" - ) + type="PFA") attn_output = attn_output.view(bsz, tgt_len, self.num_heads, self.head_dim) attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim) @@ -308,10 +290,10 @@ class MindieAttention(WhisperAttention): class MindieWhisperDecoderLayer(WhisperDecoderLayer): - def __init__(self, config, is_incre_decode,soc_version): + def __init__(self, config): super().__init__(config) self.embed_dim = config.d_model - if is_incre_decode: + if config.is_incre_decode: self.self_attn = MindieIFA( embed_dim=self.embed_dim, num_heads=config.decoder_attention_heads, @@ -324,7 +306,7 @@ class MindieWhisperDecoderLayer(WhisperDecoderLayer): is_decoder=True, config=config ) - elif soc_version == "Ascend910B4": + elif config.soc_version == "Ascend910B4": self.self_attn = MindiePFA( embed_dim=self.embed_dim, num_heads=config.decoder_attention_heads, @@ -337,7 +319,7 @@ class MindieWhisperDecoderLayer(WhisperDecoderLayer): is_decoder=True, config=config ) - elif soc_version == "Ascend310P3": + elif config.soc_version == "Ascend310P3": self.self_attn = MindieFA( embed_dim=self.embed_dim, num_heads=config.decoder_attention_heads, @@ -420,9 +402,9 @@ class MindieWhisperDecoderLayer(WhisperDecoderLayer): class MindieWhisperDecoder(WhisperDecoder): - def __init__(self, config, is_incre_decode, soc_version): + def __init__(self, config): super().__init__(config) - self.layers = nn.ModuleList([MindieWhisperDecoderLayer(config, is_incre_decode, soc_version) + self.layers = nn.ModuleList([MindieWhisperDecoderLayer(config) for _ 
in range(config.decoder_layers)]) self.config = config @@ -493,9 +475,9 @@ class MindieWhisperEncoder(WhisperEncoder): class MindieWhisperModel(WhisperModel): - def __init__(self, config, is_incre_decode, soc_version): + def __init__(self, config): super().__init__(config) - self.decoder = MindieWhisperDecoder(config, is_incre_decode, soc_version) + self.decoder = MindieWhisperDecoder(config) self.encoder = MindieWhisperEncoder(config) def forward( @@ -537,8 +519,9 @@ class MindieWhisperForConditionalGeneration(WhisperForConditionalGeneration, Gen def __init__(self, config, is_incre_decode=False, soc_version="Ascend310P3"): super().__init__(config) - config.is_use_ifa = enable_incre_flash_attention - self.model = MindieWhisperModel(config, is_incre_decode, soc_version) + config.is_incre_decode = is_incre_decode + config.soc_version = soc_version + self.model = MindieWhisperModel(config) self.has_load = False self.has_compile = False self.mindie_encoder = None -- Gitee From 2b9e8f1ceace3da6896766e63f0dac79963e47f5 Mon Sep 17 00:00:00 2001 From: tiansongxue Date: Fri, 8 Nov 2024 16:27:03 +0800 Subject: [PATCH 39/44] fix prefix --- .../audio/whisperX/compile_whisper.py | 41 ++++++++++++------- .../audio/whisperX/modeling_whisper.py | 4 +- 2 files changed, 28 insertions(+), 17 deletions(-) diff --git a/MindIE/MindIE-Torch/built-in/audio/whisperX/compile_whisper.py b/MindIE/MindIE-Torch/built-in/audio/whisperX/compile_whisper.py index 77b5efaf57..ee857fe6fa 100644 --- a/MindIE/MindIE-Torch/built-in/audio/whisperX/compile_whisper.py +++ b/MindIE/MindIE-Torch/built-in/audio/whisperX/compile_whisper.py @@ -1,12 +1,19 @@ +import os import argparse import time +import torch import mindietorch from modeling_whisper import MindieWhisperForConditionalGeneration - +FILE_REFIX_NAME = ["mindie_whisper_encoder_bs", + "mindie_decoder_prefill_bs", + "mindie_whisper_decoder_bs", + "mindie_self_scatter_bs", + "mindie_encoder_scatter_bs" + ] def compile_encoder(model : MindieWhisperForConditionalGeneration, args : argparse): - encoder = self.get_encoder() + encoder = model.get_encoder() class Encoder(torch.nn.Module): @@ -22,20 +29,24 @@ def compile_encoder(model : MindieWhisperForConditionalGeneration, input_features = torch.randn([args.bs, mel_feature_size, max_frames]) encoder_traced = torch.jit.trace(Encoder(encoder), (input_features)) input_info = [mindietorch.Input(shape=(args.bs, mel_feature_size, max_frames))] - print("Start AOE optimization_level 1.") - compiled = mindietorch.compile(encoder_traced, - inputs=input_info, - precision_policy=mindietorch.PrecisionPolicy.FP16, - soc_version=args.soc_version, - optimization_level=1 - ) - print("AOE optimize finished and start compile encoder.") + try: + print("Start AOE optimization_level 1.") + compiled = mindietorch.compile(encoder_traced, + inputs=input_info, + precision_policy=mindietorch.PrecisionPolicy.FP16, + soc_version=args.soc_version, + optimization_level=1 + ) + except Exception: + print("Using Aoe to optimize encoder model failed, but still can compile encoder model.") + else: + print("AOE optimize finished and start compile encoder.") compiled = mindietorch.compile(encoder_traced, inputs=input_info, precision_policy=mindietorch.PrecisionPolicy.FP16, soc_version=args.soc_version, ) - save_file = os.path.join(args.save_path, f"{self.file_prefix_names[0]}{args.bs}.ts") + save_file = os.path.join(args.save_path, f"{FILE_REFIX_NAME[0]}{args.bs}.ts") torch.jit.save(compiled, save_file) print(f"Compile encoder success, saved in {save_file}") @@ -53,7 
+64,7 @@ def compile_prefill_decoder(self, args): precision_policy=mindietorch.PrecisionPolicy.FP16, soc_version=args.soc_version) - save_file = os.path.join(args.save_path, f"{self.file_prefix_names[1]}{args.bs}.ts") + save_file = os.path.join(args.save_path, f"{FILE_REFIX_NAME[1]}{args.bs}.ts") torch.jit.save(prefill_decoder_compiled, save_file) print(f"Compile prefill_decoder success, saved in {save_file}.") @@ -109,7 +120,7 @@ def compile_incre_decoder(args): precision_policy=mindietorch.PrecisionPolicy.FP16, soc_version=args.soc_version, default_buffer_size_vec=[buffer_size]) - save_file = os.path.join(args.save_path, f"{self.file_prefix_names[2]}{args.bs}.ts") + save_file = os.path.join(args.save_path, f"{FILE_REFIX_NAME[2]}{args.bs}.ts") torch.jit.save(compiled_decoder, save_file) print(f"Compile whisper_decoder success, saved in {save_file}.") @@ -134,11 +145,11 @@ def compile_scatter_update(self, args): compile_self = mindietorch.compile(traced, inputs=[self_attn_info, indices_info, update_states_info], soc_version=args.soc_version) - torch.jit.save(compile_self, f"{args.save_path}/{self.file_prefix_names[3]}{args.bs}.ts") + torch.jit.save(compile_self, f"{args.save_path}/{FILE_REFIX_NAME[3]}{args.bs}.ts") compile_self = mindietorch.compile(traced, inputs=[encoder_attn_info, indices_info, update_states_info], soc_version=args.soc_version) - torch.jit.save(compile_self, f"{args.save_path}/{self.file_prefix_names[4]}{args.bs}.ts") + torch.jit.save(compile_self, f"{args.save_path}/{FILE_REFIX_NAME[4]}{args.bs}.ts") print("compile scatter success.") if __name__ == "__main__": diff --git a/MindIE/MindIE-Torch/built-in/audio/whisperX/modeling_whisper.py b/MindIE/MindIE-Torch/built-in/audio/whisperX/modeling_whisper.py index 87fada7d47..480e27db3c 100644 --- a/MindIE/MindIE-Torch/built-in/audio/whisperX/modeling_whisper.py +++ b/MindIE/MindIE-Torch/built-in/audio/whisperX/modeling_whisper.py @@ -218,9 +218,9 @@ class MindieFA(WhisperAttention): past_key_value = (key_states.to(torch.float16), value_states.to(torch.float16)) # MindFA only support BNSD layout - query_states = query_states.view(bsz, tgt_len, self.num_heads, self.head_dim).contiguous().transpose(1, 2) + query_states = query_states.view(bsz, tgt_len, self.num_heads, self.head_dim).contiguous() attn_output = torch.ops.aie.flash_attention( - query=query_states, + query=query_states.transpose(1, 2), key=key_states.transpose(1, 2), value=value_states.transpose(1, 2), num_head=self.num_heads, -- Gitee From ed46cbaddedfb2d73889363c5f55bbdd0843ac3e Mon Sep 17 00:00:00 2001 From: tiansongxue Date: Fri, 8 Nov 2024 17:22:20 +0800 Subject: [PATCH 40/44] fix prefix --- .../MindIE-Torch/built-in/audio/whisperX/compile_whisper.py | 5 +++-- .../MindIE-Torch/built-in/audio/whisperX/modeling_whisper.py | 2 +- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/MindIE/MindIE-Torch/built-in/audio/whisperX/compile_whisper.py b/MindIE/MindIE-Torch/built-in/audio/whisperX/compile_whisper.py index ee857fe6fa..a7c35007b5 100644 --- a/MindIE/MindIE-Torch/built-in/audio/whisperX/compile_whisper.py +++ b/MindIE/MindIE-Torch/built-in/audio/whisperX/compile_whisper.py @@ -6,7 +6,7 @@ import mindietorch from modeling_whisper import MindieWhisperForConditionalGeneration FILE_REFIX_NAME = ["mindie_whisper_encoder_bs", - "mindie_decoder_prefill_bs", + "mindie_decoder_prefill_bs", "mindie_whisper_decoder_bs", "mindie_self_scatter_bs", "mindie_encoder_scatter_bs" @@ -162,7 +162,8 @@ if __name__ == "__main__": parser.add_argument('-device_id', 
type=int, default=0) args = parser.parse_args() - mindie_whisper = MindieWhisperForConditionalGeneration.from_pretrained(args.model_path, args.soc_version) + mindie_whisper = MindieWhisperForConditionalGeneration.from_pretrained(args.model_path, + soc_version=args.soc_version) device = f"npu:{args.device_id}" print("Start compiling Mindie-Whisper, it will take some time, please wait.") if not args.save_path: diff --git a/MindIE/MindIE-Torch/built-in/audio/whisperX/modeling_whisper.py b/MindIE/MindIE-Torch/built-in/audio/whisperX/modeling_whisper.py index 480e27db3c..4118819257 100644 --- a/MindIE/MindIE-Torch/built-in/audio/whisperX/modeling_whisper.py +++ b/MindIE/MindIE-Torch/built-in/audio/whisperX/modeling_whisper.py @@ -985,4 +985,4 @@ class MindieWhisperForConditionalGeneration(WhisperForConditionalGeneration, Gen return { "encoder_outputs": encoder_outputs, "decoder_input_ids": decoder_input_ids - } + } \ No newline at end of file -- Gitee From 01c453c9e3413ea690b15a13a9fc3ceba7101ed2 Mon Sep 17 00:00:00 2001 From: tiansongxue Date: Mon, 11 Nov 2024 10:15:19 +0800 Subject: [PATCH 41/44] fix prefix --- .../audio/whisperX/compile_whisper.py | 71 +++++++++---------- .../audio/whisperX/modeling_whisper.py | 7 +- .../built-in/audio/whisper_large_v3/config.py | 19 +++++ 3 files changed, 57 insertions(+), 40 deletions(-) create mode 100644 MindIE/MindIE-Torch/built-in/audio/whisper_large_v3/config.py diff --git a/MindIE/MindIE-Torch/built-in/audio/whisperX/compile_whisper.py b/MindIE/MindIE-Torch/built-in/audio/whisperX/compile_whisper.py index a7c35007b5..611dce65a9 100644 --- a/MindIE/MindIE-Torch/built-in/audio/whisperX/compile_whisper.py +++ b/MindIE/MindIE-Torch/built-in/audio/whisperX/compile_whisper.py @@ -3,32 +3,27 @@ import argparse import time import torch import mindietorch +from mindietorch._enums import dtype from modeling_whisper import MindieWhisperForConditionalGeneration +from config import CompileInfo -FILE_REFIX_NAME = ["mindie_whisper_encoder_bs", - "mindie_decoder_prefill_bs", - "mindie_whisper_decoder_bs", - "mindie_self_scatter_bs", - "mindie_encoder_scatter_bs" - ] def compile_encoder(model : MindieWhisperForConditionalGeneration, - args : argparse): + args : argparse, + compile_info : CompileInfo): encoder = model.get_encoder() class Encoder(torch.nn.Module): def __init__(self, model): super().__init__() - self.model = model + compile_info.model = model def forward(self, input_features): - return self.model(input_features=input_features, return_dict=False) + return compile_info.model(input_features=input_features, return_dict=False) - mel_feature_size = 128 - max_frames = 3000 - input_features = torch.randn([args.bs, mel_feature_size, max_frames]) - encoder_traced = torch.jit.trace(Encoder(encoder), (input_features)) - input_info = [mindietorch.Input(shape=(args.bs, mel_feature_size, max_frames))] + input_features = torch.randn([args.bs, compile_info.mel_feature_size, compile_info.max_frames]) + encoder_traced = torch.jit.trace(Encoder(encoder), (compile_info.input_features)) + input_info = [mindietorch.Input(shape=(args.bs, compile_info.mel_feature_size, compile_info.max_frames))] try: print("Start AOE optimization_level 1.") compiled = mindietorch.compile(encoder_traced, @@ -50,13 +45,12 @@ def compile_encoder(model : MindieWhisperForConditionalGeneration, torch.jit.save(compiled, save_file) print(f"Compile encoder success, saved in {save_file}") -def compile_prefill_decoder(self, args): +def compile_prefill_decoder(model : MindieWhisperForConditionalGeneration, 
+                            args : argparse, compile_info : CompileInfo):
     print("Start compiling prefill_decoder.")
-    encoder_seq_len = 1500
-    hidden_size = 1280
-    encoder_outputs = torch.randn([args.bs, encoder_seq_len, hidden_size])
+    encoder_outputs = torch.randn([args.bs, compile_info.encoder_seq_len, compile_info.hidden_size])
     decoder_input_ids = torch.randint(1, 4, (args.bs, 1))
-    prefill_decoder_traced = torch.jit.trace(self.eval(), (decoder_input_ids, encoder_outputs))
+    prefill_decoder_traced = torch.jit.trace(model.eval(), (decoder_input_ids, encoder_outputs))
-    input_info = [mindietorch.Input(shape=(args.bs, 1), dtype=dtype.INT64),
-                  mindietorch.Input(shape=(args.bs, encoder_seq_len, hidden_size))]
+    input_info = [mindietorch.Input(shape=(args.bs, 1), dtype=dtype.INT64),
+                  mindietorch.Input(shape=(args.bs, compile_info.encoder_seq_len, compile_info.hidden_size))]
     prefill_decoder_compiled = mindietorch.compile(prefill_decoder_traced,
@@ -68,7 +62,7 @@ def compile_prefill_decoder(self, args):
     torch.jit.save(prefill_decoder_compiled, save_file)
     print(f"Compile prefill_decoder success, saved in {save_file}.")
 
-def compile_incre_decoder(args):
+def compile_incre_decoder(args : argparse, compile_info : CompileInfo):
 
     class Decoder(torch.nn.Module):
 
         def __init__(self, model):
@@ -79,35 +73,36 @@ def compile_incre_decoder(args):
             return self.model.forward(*args)[0]
 
     mindie_whisper = MindieWhisperForConditionalGeneration.from_pretrained(args.model_path,
+                                                                           machine_type = args.machine_type,
                                                                            is_incre_decode=True,
                                                                            soc_version=args.soc_version)
     decoder = Decoder(mindie_whisper)
     print("Start compiling decoder.")
-    encoder_outputs = torch.randn([args.bs, self.encoder_seq_len, self.hidden_size])
+    encoder_outputs = torch.randn([args.bs, compile_info.encoder_seq_len, compile_info.hidden_size])
     decoder_input_ids = torch.randint(1, 4, (args.bs, 1))
     actual_seq_len = torch.ones((args.bs))
-    all_past_key_value = [torch.randn([args.bs, self.seq_len, self.head_num, self.head_size]),
-                          torch.randn([args.bs, self.seq_len, self.head_num, self.head_size]),
-                          torch.randn([args.bs, self.encoder_seq_len, self.head_num, self.head_size]),
-                          torch.randn([args.bs, self.encoder_seq_len, self.head_num, self.head_size])
-                          ] * self.layer_nums
+    all_past_key_value = [torch.randn([args.bs, compile_info.max_decode_step, compile_info.head_num, compile_info.head_size]),
+                          torch.randn([args.bs, compile_info.max_decode_step, compile_info.head_num, compile_info.head_size]),
+                          torch.randn([args.bs, compile_info.encoder_seq_len, compile_info.head_num, compile_info.head_size]),
+                          torch.randn([args.bs, compile_info.encoder_seq_len, compile_info.head_num, compile_info.head_size])
+                          ] * compile_info.layer_nums
     traced_args = [decoder_input_ids, encoder_outputs, actual_seq_len]
     traced_args.extend(all_past_key_value)
     traced_decoder = torch.jit.trace(decoder, traced_args)
     # BSND
     key_value_infos = [
-        mindietorch.Input(shape=(args.bs, self.max_len, self.head_num, self.head_size),
+        mindietorch.Input(shape=(args.bs, compile_info.max_decode_step, compile_info.head_num, compile_info.head_size),
                           dtype=dtype.FLOAT16),
-        mindietorch.Input(shape=(args.bs, self.max_len, self.head_num, self.head_size),
+        mindietorch.Input(shape=(args.bs, compile_info.max_decode_step, compile_info.head_num, compile_info.head_size),
                           dtype=dtype.FLOAT16),
-        mindietorch.Input(shape=(args.bs, self.encoder_seq_len, self.head_num, self.head_size),
+        mindietorch.Input(shape=(args.bs, compile_info.encoder_seq_len, compile_info.head_num, compile_info.head_size),
                           dtype=dtype.FLOAT16),
-        mindietorch.Input(shape=(args.bs, self.encoder_seq_len, self.head_num, self.head_size),
+        mindietorch.Input(shape=(args.bs, compile_info.encoder_seq_len, compile_info.head_num, compile_info.head_size),
                          dtype=dtype.FLOAT16
-                          )] * self.layer_nums
+                          )] * compile_info.layer_nums
     input_info = [mindietorch.Input(shape=(args.bs, 1), dtype=dtype.INT64),  # input ids
-                  mindietorch.Input(shape=(args.bs, self.encoder_seq_len, self.hidden_size)),
+                  mindietorch.Input(shape=(args.bs, compile_info.encoder_seq_len, compile_info.hidden_size)),
                   mindietorch.Input(shape=(args.bs,), dtype=dtype.INT64)]  # actual sq len
     input_info.extend(key_value_infos)
@@ -132,8 +127,8 @@ def compile_scatter_update(self, args):
             return out
 
     bs = args.bs
-    self_past_key_value = torch.randn([bs, self.max_len, self.head_num, self.head_size])
-    encoder_past_key_value = torch.randn([bs, self.max_len, self.head_num, self.head_size])
+    self_past_key_value = torch.randn([bs, compile_info.max_len, compile_info.head_num, compile_info.head_size])
+    encoder_past_key_value = torch.randn([bs, compile_info.max_len, compile_info.head_num, compile_info.head_size])
     indices = torch.tensor([0] * bs)
     update_states = torch.randn([bs, 1, 20, 64])
     traced = torch.jit.trace(MindieScatter(), (self_past_key_value, indices, update_states))
@@ -156,14 +151,16 @@ if __name__ == "__main__":
     parser = argparse.ArgumentParser()
     parser.add_argument('-model_path', type=str, required=True, help="please provide model path.")
     parser.add_argument('-bs', type=int, default=8, help="please provide batch_size, default:8.")
-    parser.add_argument('-soc_version', type=str, choices=["Ascend310P3", "Ascend910B4"], default="Ascend310P3",
-                        help="please provide soc_version, default:Ascend310P3.")
+    parser.add_argument('-soc_version', type=str, required=True,
+                        help="please provide soc_version.")
     parser.add_argument('-save_path', type=str, default="compiled_models", help="compiled models save dir.")
+    parser.add_argument('-machine_type', type=str, choices=["300I", "800A2"], default="800A2")
     parser.add_argument('-device_id', type=int, default=0)
     args = parser.parse_args()
 
     mindie_whisper = MindieWhisperForConditionalGeneration.from_pretrained(args.model_path,
-                                                                           soc_version=args.soc_version)
+                                                                           soc_version=args.soc_version,
+                                                                           machine_type=machine_type)
     device = f"npu:{args.device_id}"
     print("Start compiling Mindie-Whisper, it will take some time, please wait.")
     if not args.save_path:
diff --git a/MindIE/MindIE-Torch/built-in/audio/whisperX/modeling_whisper.py b/MindIE/MindIE-Torch/built-in/audio/whisperX/modeling_whisper.py
index 4118819257..80e37c8edb 100644
--- a/MindIE/MindIE-Torch/built-in/audio/whisperX/modeling_whisper.py
+++ b/MindIE/MindIE-Torch/built-in/audio/whisperX/modeling_whisper.py
@@ -306,7 +306,7 @@ class MindieWhisperDecoderLayer(WhisperDecoderLayer):
             is_decoder=True,
             config=config
         )
-        elif config.soc_version == "Ascend910B4":
+        elif config.soc_version == "800A2":
             self.self_attn = MindiePFA(
                 embed_dim=self.embed_dim,
                 num_heads=config.decoder_attention_heads,
@@ -319,7 +319,7 @@ class MindieWhisperDecoderLayer(WhisperDecoderLayer):
             is_decoder=True,
             config=config
         )
-        elif config.soc_version == "Ascend310P3":
+        elif config.soc_version == "300IPro":
             self.self_attn = MindieFA(
                 embed_dim=self.embed_dim,
                 num_heads=config.decoder_attention_heads,
@@ -517,10 +517,11 @@ class MindieWhisperModel(WhisperModel):
 class MindieWhisperForConditionalGeneration(WhisperForConditionalGeneration, GenerationMixin):
 
-    def __init__(self, config, is_incre_decode=False, soc_version="Ascend310P3"):
+    def __init__(self, config, is_incre_decode=False, soc_version="Ascend310P3", machine_type="800A2"):
         super().__init__(config)
         config.is_incre_decode = is_incre_decode
         config.soc_version = soc_version
+        config.machine_type = machine_type
         self.model = MindieWhisperModel(config)
         self.has_load = False
         self.has_compile = False
diff --git a/MindIE/MindIE-Torch/built-in/audio/whisper_large_v3/config.py b/MindIE/MindIE-Torch/built-in/audio/whisper_large_v3/config.py
new file mode 100644
index 0000000000..08d322455a
--- /dev/null
+++ b/MindIE/MindIE-Torch/built-in/audio/whisper_large_v3/config.py
@@ -0,0 +1,19 @@
+
+from dataclasses import dataclass, field
+from typing import List
+
+@dataclass
+class CompileInfo:
+    prefix_name = ["mindie_whisper_encoder_bs",
+                   "mindie_decoder_prefill_bs",
+                   "mindie_whisper_decoder_bs",
+                   "mindie_self_scatter_bs",
+                   "mindie_encoder_scatter_bs"]
+    mel_feature_size = 128
+    max_frames = 3000
+    max_decode_step = 448
+    head_num = 20
+    head_size = 64
+    encoder_seq_len = 1500
+    hidden_size = 1280
+    layer_nums = 32
-- Gitee
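The `CompileInfo` constants added in this patch pin whisper-large-v3's geometry: 20 heads x 64 head size = 1280 hidden units, 1500 encoder frames, up to 448 decode steps, and 32 decoder layers. One practical use of these numbers is estimating the fp16 KV-cache footprint the statically shaped decoder graphs must hold. The sketch below is a back-of-the-envelope editorial aid, not part of the patch series; it assumes one self-attention and one cross-attention K/V pair per layer, as the compiled input shapes suggest. Note also that this patch adds config.py under whisper_large_v3/ while compile_whisper.py lives under whisperX/, so `from config import CompileInfo` only resolves if the file sits alongside the script.

```python
# Rough fp16 KV-cache estimate derived from CompileInfo (illustrative sketch only).
from config import CompileInfo  # assumes config.py is next to this script

def kv_cache_bytes(bs: int, info=CompileInfo) -> int:
    fp16 = 2  # bytes per element
    per_token = info.head_num * info.head_size       # 20 * 64 = 1280 (= hidden_size)
    self_kv = 2 * info.max_decode_step * per_token   # K and V over 448 decode slots
    cross_kv = 2 * info.encoder_seq_len * per_token  # K and V over 1500 encoder frames
    return bs * info.layer_nums * (self_kv + cross_kv) * fp16

print(f"{kv_cache_bytes(16) / 2**30:.2f} GiB")  # roughly 4.8 GiB at bs=16
```

At bs=16 this works out to nearly 5 GiB for the caches alone, which is consistent with the readme's advice to raise TORCH_AIE_NPU_CACHE_MAX_SIZE.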
From c1b15446e9219d41dd38910fe5302c03daa962d2 Mon Sep 17 00:00:00 2001
From: tiansongxue
Date: Mon, 11 Nov 2024 10:17:50 +0800
Subject: [PATCH 42/44] fix prefix

---
 .../built-in/audio/whisperX/compile_whisper.py | 10 +++++-----
 1 file changed, 5 insertions(+), 5 deletions(-)

diff --git a/MindIE/MindIE-Torch/built-in/audio/whisperX/compile_whisper.py b/MindIE/MindIE-Torch/built-in/audio/whisperX/compile_whisper.py
index 611dce65a9..676b851ffe 100644
--- a/MindIE/MindIE-Torch/built-in/audio/whisperX/compile_whisper.py
+++ b/MindIE/MindIE-Torch/built-in/audio/whisperX/compile_whisper.py
@@ -41,7 +41,7 @@ def compile_encoder(model : MindieWhisperForConditionalGeneration,
                                        precision_policy=mindietorch.PrecisionPolicy.FP16,
                                        soc_version=args.soc_version,
                                        )
-    save_file = os.path.join(args.save_path, f"{FILE_REFIX_NAME[0]}{args.bs}.ts")
+    save_file = os.path.join(args.save_path, f"{compile_info.prefix_name[0]}{args.bs}.ts")
     torch.jit.save(compiled, save_file)
     print(f"Compile encoder success, saved in {save_file}")
 
@@ -58,7 +58,7 @@ def compile_prefill_decoder(model : MindieWhisperForConditionalGeneration,
                                                    precision_policy=mindietorch.PrecisionPolicy.FP16,
                                                    soc_version=args.soc_version)
 
-    save_file = os.path.join(args.save_path, f"{FILE_REFIX_NAME[1]}{args.bs}.ts")
+    save_file = os.path.join(args.save_path, f"{compile_info.prefix_name[1]}{args.bs}.ts")
     torch.jit.save(prefill_decoder_compiled, save_file)
     print(f"Compile prefill_decoder success, saved in {save_file}.")
 
@@ -115,7 +115,7 @@ def compile_incre_decoder(args : argparse, compile_info : CompileInfo):
                                               precision_policy=mindietorch.PrecisionPolicy.FP16,
                                               soc_version=args.soc_version,
                                               default_buffer_size_vec=[buffer_size])
-    save_file = os.path.join(args.save_path, f"{FILE_REFIX_NAME[2]}{args.bs}.ts")
+    save_file = os.path.join(args.save_path, f"{compile_info.prefix_name[2]}{args.bs}.ts")
     torch.jit.save(compiled_decoder, save_file)
 
     print(f"Compile whisper_decoder success, saved in {save_file}.")
@@ -140,11 +140,11 @@ def compile_scatter_update(self, args):
     compile_self = mindietorch.compile(traced,
                                        inputs=[self_attn_info, indices_info, update_states_info],
                                        soc_version=args.soc_version)
-    torch.jit.save(compile_self, f"{args.save_path}/{FILE_REFIX_NAME[3]}{args.bs}.ts")
+    torch.jit.save(compile_self, f"{args.save_path}/{compile_info.prefix_name[3]}{args.bs}.ts")
     compile_self = mindietorch.compile(traced,
                                        inputs=[encoder_attn_info, indices_info, update_states_info],
                                        soc_version=args.soc_version)
-    torch.jit.save(compile_self, f"{args.save_path}/{FILE_REFIX_NAME[4]}{args.bs}.ts")
+    torch.jit.save(compile_self,
f"{args.save_path}/{compile_info.prefix_name[4]}{args.bs}.ts") print("compile scatter success.") if __name__ == "__main__": -- Gitee From 67c8d577edc759218cd18b955589c0669282fbc9 Mon Sep 17 00:00:00 2001 From: tiansongxue Date: Mon, 11 Nov 2024 10:22:34 +0800 Subject: [PATCH 43/44] fix prefix --- .../MindIE-Torch/built-in/audio/whisperX/compile_whisper.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/MindIE/MindIE-Torch/built-in/audio/whisperX/compile_whisper.py b/MindIE/MindIE-Torch/built-in/audio/whisperX/compile_whisper.py index 676b851ffe..7ac2b15ca0 100644 --- a/MindIE/MindIE-Torch/built-in/audio/whisperX/compile_whisper.py +++ b/MindIE/MindIE-Torch/built-in/audio/whisperX/compile_whisper.py @@ -16,10 +16,10 @@ def compile_encoder(model : MindieWhisperForConditionalGeneration, def __init__(self, model): super().__init__() - compile_info.model = model + self.model = model def forward(self, input_features): - return compile_info.model(input_features=input_features, return_dict=False) + return self.model(input_features=input_features, return_dict=False) input_features = torch.randn([args.bs, compile_info.mel_feature_size, compile_info.max_frames]) encoder_traced = torch.jit.trace(Encoder(encoder), (compile_info.input_features)) -- Gitee From abf99a5b30ab89abc1fdf70708483e33ecaa4ac1 Mon Sep 17 00:00:00 2001 From: tiansongxue Date: Mon, 11 Nov 2024 11:12:33 +0800 Subject: [PATCH 44/44] fix prefix --- .../audio/whisperX/compile_whisper.py | 20 +++++++++---------- .../audio/whisperX/modeling_whisper.py | 11 +++++----- 2 files changed, 15 insertions(+), 16 deletions(-) diff --git a/MindIE/MindIE-Torch/built-in/audio/whisperX/compile_whisper.py b/MindIE/MindIE-Torch/built-in/audio/whisperX/compile_whisper.py index 7ac2b15ca0..3f4bfd3081 100644 --- a/MindIE/MindIE-Torch/built-in/audio/whisperX/compile_whisper.py +++ b/MindIE/MindIE-Torch/built-in/audio/whisperX/compile_whisper.py @@ -22,7 +22,7 @@ def compile_encoder(model : MindieWhisperForConditionalGeneration, return self.model(input_features=input_features, return_dict=False) input_features = torch.randn([args.bs, compile_info.mel_feature_size, compile_info.max_frames]) - encoder_traced = torch.jit.trace(Encoder(encoder), (compile_info.input_features)) + encoder_traced = torch.jit.trace(Encoder(encoder), (input_features)) input_info = [mindietorch.Input(shape=(args.bs, compile_info.mel_feature_size, compile_info.max_frames))] try: print("Start AOE optimization_level 1.") @@ -120,15 +120,15 @@ def compile_incre_decoder(args : argparse, compile_info : CompileInfo): print(f"Compile whisper_decoder success, saved in {save_file}.") -def compile_scatter_update(self, args): +def compile_scatter_update(self, args, compile_info): class MindieScatter(torch.nn.Module): def forward(self, past_key_value, indices, update_states): out = torch.ops.aie.scatter_update(past_key_value, indices, update_states, 1) return out bs = args.bs - self_past_key_value = torch.randn([bs, compile_info.max_len, compile_info.head_num, compile_info.head_size]) - encoder_past_key_value = torch.randn([bs, compile_info.max_len, compile_info.head_num, compile_info.head_size]) + self_past_key_value = torch.randn([bs, compile_info.max_decode_step, compile_info.head_num, compile_info.head_size]) + encoder_past_key_value = torch.randn([bs, compile_info.max_decode_step, compile_info.head_num, compile_info.head_size]) indices = torch.tensor([0] * bs) update_states = torch.randn([bs, 1, 20, 64]) traced = torch.jit.trace(MindieScatter(), (self_past_key_value, 
indices, update_states))
@@ -154,13 +154,13 @@ if __name__ == "__main__":
     parser.add_argument('-soc_version', type=str, required=True,
                         help="please provide soc_version.")
     parser.add_argument('-save_path', type=str, default="compiled_models", help="compiled models save dir.")
-    parser.add_argument('-machine_type', type=str, choices=["300I", "800A2"], default="800A2")
+    parser.add_argument('-machine_type', type=str, choices=["300IPro", "800IA2"], default="800IA2")
     parser.add_argument('-device_id', type=int, default=0)
     args = parser.parse_args()
 
     mindie_whisper = MindieWhisperForConditionalGeneration.from_pretrained(args.model_path,
-                                                                           soc_version=args.soc_version,
-                                                                           machine_type=machine_type)
+                                                                           machine_type=args.machine_type)
     device = f"npu:{args.device_id}"
     print("Start compiling Mindie-Whisper, it will take some time, please wait.")
     if not args.save_path:
@@ -171,7 +171,7 @@ if __name__ == "__main__":
     else:
         print(f"Directory {args.save_path} already exists.")
-    compile_scatter_update(mindie_whisper, args)
-    compile_encoder(mindie_whisper, args)
-    compile_prefill_decoder(mindie_whisper, args)
-    compile_incre_decoder(args)
\ No newline at end of file
+    compile_scatter_update(mindie_whisper, args, CompileInfo)
+    compile_encoder(mindie_whisper, args, CompileInfo)
+    compile_prefill_decoder(mindie_whisper, args, CompileInfo)
+    compile_incre_decoder(args, CompileInfo)
\ No newline at end of file
diff --git a/MindIE/MindIE-Torch/built-in/audio/whisperX/modeling_whisper.py b/MindIE/MindIE-Torch/built-in/audio/whisperX/modeling_whisper.py
index 80e37c8edb..639c847349 100644
--- a/MindIE/MindIE-Torch/built-in/audio/whisperX/modeling_whisper.py
+++ b/MindIE/MindIE-Torch/built-in/audio/whisperX/modeling_whisper.py
@@ -306,7 +306,7 @@ class MindieWhisperDecoderLayer(WhisperDecoderLayer):
             is_decoder=True,
             config=config
         )
-        elif config.soc_version == "800A2":
+        elif config.machine_type == "800IA2":
             self.self_attn = MindiePFA(
                 embed_dim=self.embed_dim,
                 num_heads=config.decoder_attention_heads,
@@ -319,7 +319,7 @@ class MindieWhisperDecoderLayer(WhisperDecoderLayer):
             is_decoder=True,
             config=config
         )
-        elif config.soc_version == "300IPro":
+        elif config.machine_type == "300IPro":
             self.self_attn = MindieFA(
                 embed_dim=self.embed_dim,
                 num_heads=config.decoder_attention_heads,
@@ -334,7 +334,7 @@ class MindieWhisperDecoderLayer(WhisperDecoderLayer):
         )
         else:
             raise ValueError(f"Unsupported parameter combination, e.g. "
-                             f"is_incre_decode: {is_incre_decode} soc_version {soc_version}")
+                             f"is_incre_decode: {config.is_incre_decode} machine_type {config.machine_type}")
 
     def forward(
         self,
@@ -517,10 +517,9 @@ class MindieWhisperModel(WhisperModel):
 class MindieWhisperForConditionalGeneration(WhisperForConditionalGeneration, GenerationMixin):
 
-    def __init__(self, config, is_incre_decode=False, soc_version="Ascend310P3", machine_type="800A2"):
+    def __init__(self, config, is_incre_decode=False, machine_type="800IA2"):
         super().__init__(config)
         config.is_incre_decode = is_incre_decode
-        config.soc_version = soc_version
         config.machine_type = machine_type
         self.model = MindieWhisperModel(config)
         self.has_load = False
@@ -632,7 +631,7 @@ class MindieWhisperForConditionalGeneration(WhisperForConditionalGeneration, Gen
             return self.model.forward(*args)[0]
 
         mindie_whisper = MindieWhisperForConditionalGeneration.from_pretrained(args.model_path,
-                                                                               enable_incre_flash_attention=True)
+                                                                               is_incre_decode=True)
         decoder = Decoder(mindie_whisper)
         print("Start compiling decoder.")
-- Gitee
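A closing note on the scatter-update path these last patches keep refining: `torch.ops.aie.scatter_update(past_key_value, indices, update_states, 1)` writes one decode step's key/value states into a preallocated `(bs, max_decode_step, head_num, head_size)` cache at a per-sample position, instead of concatenating a growing tensor every step. The preallocated cache keeps all graph shapes static, which the mindietorch compilation above requires, while `actual_seq_len` tells attention how much of the cache is valid. A plain-PyTorch sketch of the assumed semantics follows; the exact contract of the aie operator may differ.

```python
import torch

def scatter_update_ref(cache: torch.Tensor, indices: torch.Tensor,
                       update: torch.Tensor) -> torch.Tensor:
    """Assumed semantics of torch.ops.aie.scatter_update(cache, indices, update, axis=1).

    cache:   (bs, max_decode_step, head_num, head_size), fp16 in the compiled graphs
    indices: (bs,) write position of the current decode step, per sample
    update:  (bs, 1, head_num, head_size), this step's key or value states
    """
    for b in range(cache.shape[0]):
        cache[b, indices[b]] = update[b, 0]
    return cache

# One decode step writes into its slot; samples may sit at different positions.
bs, max_decode_step, head_num, head_size = 2, 448, 20, 64
cache = torch.zeros(bs, max_decode_step, head_num, head_size)
positions = torch.tensor([3, 7])
cache = scatter_update_ref(cache, positions, torch.randn(bs, 1, head_num, head_size))
```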