diff --git a/ACL_PyTorch/built-in/audio/CosyVoice/CosyVoice2/300I/diff_CosyVoice_300I.patch b/ACL_PyTorch/built-in/audio/CosyVoice/CosyVoice2/300I/diff_CosyVoice_300I.patch new file mode 100644 index 0000000000000000000000000000000000000000..a760ffd311959ddbbaf88b20a60392ae9754df6a --- /dev/null +++ b/ACL_PyTorch/built-in/audio/CosyVoice/CosyVoice2/300I/diff_CosyVoice_300I.patch @@ -0,0 +1,693 @@ +diff --git a/cosyvoice/cli/cosyvoice.py b/cosyvoice/cli/cosyvoice.py +index e2d62e2..0d4f860 100644 +--- a/cosyvoice/cli/cosyvoice.py ++++ b/cosyvoice/cli/cosyvoice.py +@@ -13,11 +13,14 @@ + # limitations under the License. + import os + import time ++import platform ++import datetime + from typing import Generator + from tqdm import tqdm + from hyperpyyaml import load_hyperpyyaml + from modelscope import snapshot_download + import torch ++from ais_bench.infer.interface import InferSession + from cosyvoice.cli.frontend import CosyVoiceFrontEnd + from cosyvoice.cli.model import CosyVoiceModel, CosyVoice2Model + from cosyvoice.utils.file_utils import logging +@@ -68,9 +71,12 @@ class CosyVoice: + model_input = self.frontend.frontend_sft(i, spk_id) + start_time = time.time() + logging.info('synthesis text {}'.format(i)) +- for model_output in self.model.tts(**model_input, stream=stream, speed=speed): ++ for idx, model_output in enumerate(self.model.tts(**model_input, stream=stream, speed=speed)): + speech_len = model_output['tts_speech'].shape[1] / self.sample_rate +- logging.info('yield speech len {}, rtf {}'.format(speech_len, (time.time() - start_time) / speech_len)) ++ if idx == 0: ++ logging.info('yield speech len {}, rtf {}, TTFT {}'.format(speech_len, (time.time() - start_time) / speech_len, time.time() - start_time)) ++ else: ++ logging.info('yield speech len {}, rtf {}'.format(speech_len, (time.time() - start_time) / speech_len)) + yield model_output + start_time = time.time() + +@@ -82,9 +88,12 @@ class CosyVoice: + model_input = self.frontend.frontend_zero_shot(i, prompt_text, prompt_speech_16k, self.sample_rate) + start_time = time.time() + logging.info('synthesis text {}'.format(i)) +- for model_output in self.model.tts(**model_input, stream=stream, speed=speed): ++ for idx, model_output in enumerate(self.model.tts(**model_input, stream=stream, speed=speed)): + speech_len = model_output['tts_speech'].shape[1] / self.sample_rate +- logging.info('yield speech len {}, rtf {}'.format(speech_len, (time.time() - start_time) / speech_len)) ++ if idx == 0: ++ logging.info('yield speech len {}, rtf {}, TTFT {}'.format(speech_len, (time.time() - start_time) / speech_len, time.time() - start_time)) ++ else: ++ logging.info('yield speech len {}, rtf {}'.format(speech_len, (time.time() - start_time) / speech_len)) + yield model_output + start_time = time.time() + +@@ -126,7 +135,7 @@ class CosyVoice: + + class CosyVoice2(CosyVoice): + +- def __init__(self, model_dir, load_jit=False, load_trt=False, fp16=False): ++ def __init__(self, model_dir, load_jit=False, load_trt=False, fp16=False, load_om=False): + self.instruct = True if '-Instruct' in model_dir else False + self.model_dir = model_dir + self.fp16 = fp16 +@@ -155,6 +164,16 @@ class CosyVoice2(CosyVoice): + self.model.load_trt('{}/flow.decoder.estimator.{}.mygpu.plan'.format(model_dir, 'fp16' if self.fp16 is True else 'fp32'), + '{}/flow.decoder.estimator.fp32.onnx'.format(model_dir), + self.fp16) ++ if load_om: ++ arch = platform.machine() ++ system = platform.system().lower() ++ flow_om = InferSession(0, '{}/flow_{}_{}.om'.format(model_dir, 
system ,arch)) ++ flow_om_static = InferSession(0, '{}/flow_static.om'.format(model_dir)) ++ speech_om = InferSession(0, '{}/speech_{}_{}.om'.format(model_dir, system ,arch)) ++ self.frontend.speech_om = speech_om ++ self.frontend.flow_om = flow_om ++ self.model.flow.decoder.flow_om_static = flow_om_static ++ self.model.flow.decoder.flow_om = flow_om + del configs + + def inference_instruct(self, *args, **kwargs): +@@ -171,3 +190,19 @@ class CosyVoice2(CosyVoice): + logging.info('yield speech len {}, rtf {}'.format(speech_len, (time.time() - start_time) / speech_len)) + yield model_output + start_time = time.time() ++ ++ def inference_sft_streaming_input(self, tts_text, char_idx, spk_id, user_id, input_end, stream=False, speed=1.0, text_frontend=True): ++ for i in [tts_text]: ++ model_input = self.frontend.frontend_sft(i, spk_id) ++ model_input["user_id"] = user_id ++ model_input["input_end"] = input_end ++ model_input['char_idx'] = char_idx ++ ++ start_time = time.time() ++ # print('synthesis text {}'.format(i)) ++ for model_output in self.model.tts_streaming_input(**model_input, stream=stream, speed=speed): ++ speech_len = model_output['tts_speech'].shape[1] / self.sample_rate ++ print("finish 1 chunk inference ", datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S.%f')) ++ logging.info('yield speech len {}, rtf {}'.format(speech_len, (time.time() - start_time) / speech_len)) ++ yield model_output ++ start_time = time.time() +diff --git a/cosyvoice/cli/frontend.py b/cosyvoice/cli/frontend.py +index 6e10f00..25ad767 100644 +--- a/cosyvoice/cli/frontend.py ++++ b/cosyvoice/cli/frontend.py +@@ -71,6 +71,8 @@ class CosyVoiceFrontEnd: + self.zh_tn_model = ZhNormalizer(remove_erhua=False, full_to_half=False, overwrite_cache=True) + self.en_tn_model = EnNormalizer() + self.inflect_parser = inflect.engine() ++ self.speech_om = None ++ self.flow_om = None + + def _extract_text_token(self, text): + if isinstance(text, Generator): +@@ -92,11 +94,16 @@ class CosyVoiceFrontEnd: + def _extract_speech_token(self, speech): + assert speech.shape[1] / 16000 <= 30, 'do not support extract speech token for audio longer than 30s' + feat = whisper.log_mel_spectrogram(speech, n_mels=128) +- speech_token = self.speech_tokenizer_session.run(None, +- {self.speech_tokenizer_session.get_inputs()[0].name: +- feat.detach().cpu().numpy(), +- self.speech_tokenizer_session.get_inputs()[1].name: +- np.array([feat.shape[2]], dtype=np.int32)})[0].flatten().tolist() ++ if torch.npu.is_available() and self.speech_om: ++ feed = [feat.detach().cpu().numpy(), np.array([feat.shape[2]], dtype=np.int32)] ++ speech_token = self.speech_om.infer(feed, mode='dymshape', custom_sizes=[100000000])[0].flatten().tolist() ++ self.flow_om.set_context() ++ else: ++ speech_token = self.speech_tokenizer_session.run(None, ++ {self.speech_tokenizer_session.get_inputs()[0].name: ++ feat.detach().cpu().numpy(), ++ self.speech_tokenizer_session.get_inputs()[1].name: ++ np.array([feat.shape[2]], dtype=np.int32)})[0].flatten().tolist() + speech_token = torch.tensor([speech_token], dtype=torch.int32).to(self.device) + speech_token_len = torch.tensor([speech_token.shape[1]], dtype=torch.int32).to(self.device) + return speech_token, speech_token_len +diff --git a/cosyvoice/cli/model.py b/cosyvoice/cli/model.py +index 9ebf8cb..407f1ae 100644 +--- a/cosyvoice/cli/model.py ++++ b/cosyvoice/cli/model.py +@@ -14,6 +14,7 @@ + import os + from typing import Generator + import torch ++import torch_npu + import numpy as np + import threading + import time +@@ 
-99,7 +100,7 @@ class CosyVoiceModel: + self.flow.decoder.estimator = self.flow.decoder.estimator_engine.create_execution_context() + + def llm_job(self, text, prompt_text, llm_prompt_speech_token, llm_embedding, uuid): +- with self.llm_context: ++ with self.llm_context(): + if isinstance(text, Generator): + assert isinstance(self, CosyVoice2Model), 'streaming input text is only implemented for CosyVoice2!' + for i in self.llm.inference_bistream(text=text, +@@ -307,13 +308,25 @@ class CosyVoice2Model(CosyVoiceModel): + self.speech_window = np.hamming(2 * self.source_cache_len) + # rtf and decoding related + self.stream_scale_factor = 1 +- self.llm_context = torch.cuda.stream(torch.cuda.Stream(self.device)) if torch.cuda.is_available() else nullcontext() ++ if torch.cuda.is_available(): ++ stream = torch.cuda.Stream(device=self.device) ++ self.llm_context = lambda: torch.cuda.stream(stream) ++ else: ++ self.llm_context = lambda: contextlib.nullcontext() + self.lock = threading.Lock() + # dict used to store session related variable + self.tts_speech_token_dict = {} + self.llm_end_dict = {} + self.hift_cache_dict = {} + ++ # add for support streaming input ++ self.first_chunk_size = 20 ++ self.token_offset_dict = {} ++ self.prompt_text_dict = {} ++ self.prompt_speech_token_dict = {} ++ self.speech_feat_dict = {} ++ self.embedding_dict = {} ++ + def load_jit(self, flow_encoder_model): + flow_encoder = torch.jit.load(flow_encoder_model, map_location=self.device) + self.flow.encoder = flow_encoder +@@ -409,3 +422,83 @@ class CosyVoice2Model(CosyVoiceModel): + self.tts_speech_token_dict.pop(this_uuid) + self.llm_end_dict.pop(this_uuid) + torch.cuda.empty_cache() ++ ++ def tts_streaming_input(self, text, char_idx, flow_embedding, llm_embedding=torch.zeros(0, 192), ++ prompt_text=torch.zeros(1, 0, dtype=torch.int32), ++ llm_prompt_speech_token=torch.zeros(1, 0, dtype=torch.int32), ++ flow_prompt_speech_token=torch.zeros(1, 0, dtype=torch.int32), ++ prompt_speech_feat=torch.zeros(1, 0, 80), stream=False, speed=1.0, **kwargs): ++ this_uuid = kwargs.get("user_id", "AscendDefaultUser") ++ if this_uuid not in self.tts_speech_token_dict: ++ self.tts_speech_token_dict[this_uuid], self.llm_end_dict[this_uuid] = [], False ++ self.hift_cache_dict[this_uuid] = None ++ self.token_offset_dict[this_uuid] = 0 ++ ++ self.prompt_text_dict[this_uuid] = prompt_text ++ self.prompt_speech_token_dict[this_uuid] = flow_prompt_speech_token ++ self.speech_feat_dict[this_uuid] = prompt_speech_feat ++ self.embedding_dict[this_uuid] = flow_embedding ++ else: ++ prompt_text = self.prompt_text_dict[this_uuid] ++ llm_prompt_speech_token = self.prompt_speech_token_dict[this_uuid] ++ flow_prompt_speech_token = self.prompt_speech_token_dict[this_uuid] ++ flow_embedding = self.embedding_dict[this_uuid] ++ llm_embedding = self.embedding_dict[this_uuid] ++ prompt_speech_feat = self.speech_feat_dict[this_uuid] ++ ++ for i in self.llm.inference_bistream_streaming_input(text=text, ++ char_idx=torch.tensor([char_idx]).to(self.device), ++ prompt_text=prompt_text.to(self.device), ++ prompt_text_len=torch.tensor([prompt_text.shape[1]], dtype=torch.int32).to(self.device), ++ prompt_speech_token=llm_prompt_speech_token.to(self.device), ++ prompt_speech_token_len=torch.tensor([llm_prompt_speech_token.shape[1]], dtype=torch.int32).to(self.device), ++ embedding=llm_embedding.to(self.device), ++ uuid=this_uuid, input_end=kwargs['input_end']): ++ self.tts_speech_token_dict[this_uuid].append(i) ++ ++ assert stream is True, "output must be streaming" 
++ ++ while True: ++ is_first_chunk_ready = (self.token_offset_dict[this_uuid] == 0 and len(self.tts_speech_token_dict[this_uuid]) >= self.first_chunk_size + self.flow.pre_lookahead_len) ++ is_next_chunk_ready = (self.token_offset_dict[this_uuid] > 0 and len(self.tts_speech_token_dict[this_uuid]) - self.token_offset_dict[this_uuid] >= self.token_hop_len + self.flow.pre_lookahead_len) ++ if is_first_chunk_ready or is_next_chunk_ready: ++ if self.token_offset_dict[this_uuid] == 0: ++ this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid][:self.first_chunk_size + self.flow.pre_lookahead_len]).unsqueeze(dim=0) ++ else: ++ this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid][:self.token_offset_dict[this_uuid] + self.token_hop_len + self.flow.pre_lookahead_len]).unsqueeze(dim=0) # 0-53, 0-103, 0-153... ++ this_tts_speech = self.token2wav(token=this_tts_speech_token, ++ prompt_token=flow_prompt_speech_token, ++ prompt_feat=prompt_speech_feat, ++ embedding=flow_embedding, ++ uuid=this_uuid, ++ token_offset=self.token_offset_dict[this_uuid], ++ finalize=False) ++ if self.token_offset_dict[this_uuid] == 0: ++ self.token_offset_dict[this_uuid] += self.first_chunk_size ++ else: ++ self.token_offset_dict[this_uuid] += self.token_hop_len ++ yield {'tts_speech': this_tts_speech.cpu()} ++ # 是否需要退出循环(token 不够下一次推理) ++ if len(self.tts_speech_token_dict[this_uuid]) - self.token_offset_dict[this_uuid] < self.token_hop_len + self.flow.pre_lookahead_len: ++ break ++ ++ if kwargs['input_end'] is True: ++ this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid]).unsqueeze(dim=0) ++ this_tts_speech = self.token2wav(token=this_tts_speech_token, ++ prompt_token=flow_prompt_speech_token, ++ prompt_feat=prompt_speech_feat, ++ embedding=flow_embedding, ++ uuid=this_uuid, ++ token_offset=self.token_offset_dict[this_uuid], ++ finalize=True) ++ yield {'tts_speech': this_tts_speech.cpu()} ++ ++ self.tts_speech_token_dict.pop(this_uuid) ++ self.llm_end_dict.pop(this_uuid) ++ self.hift_cache_dict.pop(this_uuid) ++ ++ self.token_offset_dict.pop(this_uuid) ++ self.prompt_text_dict.pop(this_uuid) ++ self.prompt_speech_token_dict.pop(this_uuid) ++ self.speech_feat_dict.pop(this_uuid) ++ self.embedding_dict.pop(this_uuid) +diff --git a/cosyvoice/flow/flow_matching.py b/cosyvoice/flow/flow_matching.py +index 6a60f6d..fbe7545 100644 +--- a/cosyvoice/flow/flow_matching.py ++++ b/cosyvoice/flow/flow_matching.py +@@ -14,6 +14,7 @@ + import threading + import torch + import torch.nn.functional as F ++import numpy as np + from matcha.models.components.flow_matching import BASECFM + + +@@ -32,6 +33,8 @@ class ConditionalCFM(BASECFM): + # Just change the architecture of the estimator here + self.estimator = estimator + self.lock = threading.Lock() ++ self.flow_om = None ++ self.flow_om_static = None + + @torch.inference_mode() + def forward(self, mu, mask, n_timesteps, temperature=1.0, spks=None, cond=None, prompt_len=0, flow_cache=torch.zeros(1, 80, 0, 2)): +@@ -105,12 +108,26 @@ class ConditionalCFM(BASECFM): + t_in[:] = t.unsqueeze(0) + spks_in[0] = spks + cond_in[0] = cond +- dphi_dt = self.forward_estimator( +- x_in, mask_in, +- mu_in, t_in, +- spks_in, +- cond_in +- ) ++ # 动态分档推理, 在流式输出中,每次输出的token数目固定,可以采取动态分档模型执行推理 ++ if torch.npu.is_available() and self.flow_om_static and x.size(2)%100==0 and x.size(2)<800: ++ feed_list = [x_in, mask_in, mu_in, t_in, spks_in, cond_in] ++ feed = [i.cpu().detach().numpy().astype(np.float32) for i in feed_list] ++ dphi_dt = 
self.flow_om_static.infer(feed, mode="dymdims") ++ self.flow_om.set_context() ++ dphi_dt = torch.from_numpy(dphi_dt[0]).npu() ++ # 输出的token数目不固定场景采用动态模型推理 ++ elif torch.npu.is_available() and self.flow_om: ++ feed_list = [x_in, mask_in, mu_in, t_in, spks_in, cond_in] ++ feed = [i.cpu().detach().numpy().astype(np.float32) for i in feed_list] ++ dphi_dt = self.flow_om.infer(feed, mode="dymshape", custom_sizes=10000000) ++ dphi_dt = torch.from_numpy(dphi_dt[0]).npu() ++ else: ++ dphi_dt = self.forward_estimator( ++ x_in, mask_in, ++ mu_in, t_in, ++ spks_in, ++ cond_in ++ ) + dphi_dt, cfg_dphi_dt = torch.split(dphi_dt, [x.size(0), x.size(0)], dim=0) + dphi_dt = ((1.0 + self.inference_cfg_rate) * dphi_dt - self.inference_cfg_rate * cfg_dphi_dt) + x = x + dt * dphi_dt +diff --git a/cosyvoice/hifigan/generator.py b/cosyvoice/hifigan/generator.py +index c47bf05..7dd9fb0 100644 +--- a/cosyvoice/hifigan/generator.py ++++ b/cosyvoice/hifigan/generator.py +@@ -20,9 +20,11 @@ from scipy.signal import get_window + import torch + import torch.nn as nn + import torch.nn.functional as F ++import torch_npu + from torch.nn import Conv1d + from torch.nn import ConvTranspose1d + from torch.nn.utils import remove_weight_norm ++from torch.nn.utils.parametrize import remove_parametrizations + from torch.nn.utils.parametrizations import weight_norm + from torch.distributions.uniform import Uniform + +@@ -99,8 +101,8 @@ class ResBlock(torch.nn.Module): + + def remove_weight_norm(self): + for idx in range(len(self.convs1)): +- remove_weight_norm(self.convs1[idx]) +- remove_weight_norm(self.convs2[idx]) ++ remove_parametrizations(self.convs1[idx], "weight") ++ remove_parametrizations(self.convs2[idx], "weight") + + + class SineGen(torch.nn.Module): +@@ -319,22 +321,19 @@ class HiFTGenerator(nn.Module): + def remove_weight_norm(self): + print('Removing weight norm...') + for l in self.ups: +- remove_weight_norm(l) ++ remove_parametrizations(l, 'weight') + for l in self.resblocks: + l.remove_weight_norm() +- remove_weight_norm(self.conv_pre) +- remove_weight_norm(self.conv_post) +- self.m_source.remove_weight_norm() +- for l in self.source_downs: +- remove_weight_norm(l) ++ remove_parametrizations(self.conv_pre, 'weight') ++ remove_parametrizations(self.conv_post, 'weight') + for l in self.source_resblocks: + l.remove_weight_norm() + + def _stft(self, x): + spec = torch.stft( +- x, +- self.istft_params["n_fft"], self.istft_params["hop_len"], self.istft_params["n_fft"], window=self.stft_window.to(x.device), +- return_complex=True) ++ x.cpu(), ++ self.istft_params["n_fft"], self.istft_params["hop_len"], self.istft_params["n_fft"], window=self.stft_window.cpu(), ++ return_complex=True).npu() + spec = torch.view_as_real(spec) # [B, F, TT, 2] + return spec[..., 0], spec[..., 1] + +@@ -342,13 +341,11 @@ class HiFTGenerator(nn.Module): + magnitude = torch.clip(magnitude, max=1e2) + real = magnitude * torch.cos(phase) + img = magnitude * torch.sin(phase) +- inverse_transform = torch.istft(torch.complex(real, img), self.istft_params["n_fft"], self.istft_params["hop_len"], +- self.istft_params["n_fft"], window=self.stft_window.to(magnitude.device)) ++ inverse_transform = torch.istft(torch.complex(real, img).cpu(), self.istft_params["n_fft"], self.istft_params["hop_len"], ++ self.istft_params["n_fft"], window=self.stft_window.cpu()).npu() + return inverse_transform + +- def decode(self, x: torch.Tensor, s: torch.Tensor = torch.zeros(1, 1, 0)) -> torch.Tensor: +- s_stft_real, s_stft_imag = self._stft(s.squeeze(1)) +- s_stft = 
torch.cat([s_stft_real, s_stft_imag], dim=1) ++ def decode(self, x: torch.Tensor, s_stft: torch.Tensor, index: torch.int) -> torch.Tensor: + + x = self.conv_pre(x) + for i in range(self.num_upsamples): +@@ -356,7 +353,7 @@ class HiFTGenerator(nn.Module): + x = self.ups[i](x) + + if i == self.num_upsamples - 1: +- x = self.reflection_pad(x) ++ x = torch.cat((x, x[:,:,-2:-1]), -1) + + # fusion + si = self.source_downs[i](s_stft) +@@ -373,12 +370,10 @@ class HiFTGenerator(nn.Module): + + x = F.leaky_relu(x) + x = self.conv_post(x) +- magnitude = torch.exp(x[:, :self.istft_params["n_fft"] // 2 + 1, :]) +- phase = torch.sin(x[:, self.istft_params["n_fft"] // 2 + 1:, :]) # actually, sin is redundancy ++ magnitude = torch.exp(x[:, :index, :]) ++ phase = torch.sin(x[:, index:, :]) # actually, sin is redundancy + +- x = self._istft(magnitude, phase) +- x = torch.clamp(x, -self.audio_limit, self.audio_limit) +- return x ++ return magnitude, phase + + def forward( + self, +@@ -407,5 +402,12 @@ class HiFTGenerator(nn.Module): + # use cache_source to avoid glitch + if cache_source.shape[2] != 0: + s[:, :, :cache_source.shape[2]] = cache_source +- generated_speech = self.decode(x=speech_feat, s=s) ++ # torchair编译,对decode函数做部分适配 ++ s_stft_real, s_stft_imag = self._stft(s.squeeze(1)) ++ s_stft = torch.cat([s_stft_real, s_stft_imag], dim=1) ++ # 字典取值操作无法被dynamo编译,把decode内部的index拿到外面计算 ++ index = self.istft_params["n_fft"] // 2 + 1 ++ magnitude, phase = self.decode(x=speech_feat, s_stft=s_stft, index=index) ++ x = self._istft(magnitude, phase) ++ generated_speech = torch.clamp(x, -self.audio_limit, self.audio_limit) + return generated_speech, s +diff --git a/cosyvoice/llm/llm.py b/cosyvoice/llm/llm.py +index bbd3305..1380dad 100644 +--- a/cosyvoice/llm/llm.py ++++ b/cosyvoice/llm/llm.py +@@ -11,6 +11,7 @@ + # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + # See the License for the specific language governing permissions and + # limitations under the License. ++import math + from typing import Dict, Optional, Callable, List, Generator + import torch + from torch import nn +@@ -229,16 +230,17 @@ class Qwen2Encoder(torch.nn.Module): + super().__init__() + self.model = Qwen2ForCausalLM.from_pretrained(pretrain_path) + +- def forward_one_step(self, xs, masks, cache=None): +- input_masks = masks[:, -1, :] +- outs = self.model( +- inputs_embeds=xs, +- attention_mask=input_masks, +- output_hidden_states=True, +- return_dict=True, +- use_cache=True, +- past_key_values=cache, +- ) ++ def forward_one_step(self, xs, masks, prompt_length, cache=None): ++ with torch.no_grad(): ++ outs = self.model( ++ inputs_embeds=xs, ++ attention_mask=masks, ++ prompt_length=prompt_length, ++ output_hidden_states=True, ++ return_dict=True, ++ use_cache=True, ++ past_key_values=cache, ++ ) + xs = outs.hidden_states[-1] + new_cache = outs.past_key_values + return xs, new_cache +@@ -283,6 +285,15 @@ class Qwen2LM(TransformerLM): + self.sampling = sampling + self.mix_ratio = mix_ratio + ++ # 5. added for support streaming input ++ self.prompt_speech_token_emb_dict = {} ++ self.lm_input_dict = {} ++ self.out_tokens_dict = {} ++ self.cache_dict = {} ++ self.text_cache_dict = {} ++ self.next_fill_index = {} ++ self.prompt_length = {} ++ + @torch.inference_mode() + def inference( + self, +@@ -318,9 +329,16 @@ class Qwen2LM(TransformerLM): + # 5. 
step by step decode + out_tokens = [] + cache = None ++ input_length = lm_input.shape[1] + for i in range(max_len): ++ prompt_length = input_length + i ++ if i == 0: ++ masks = torch.triu(torch.ones((1, 1, lm_input.shape[1], lm_input.shape[1]), device=lm_input.device), diagonal=1).to(lm_input.dtype) * -10000.0 ++ else: ++ masks = None + y_pred, cache = self.llm.forward_one_step(lm_input, +- masks=torch.tril(torch.ones((1, lm_input.shape[1], lm_input.shape[1]), device=lm_input.device)).to(torch.bool), ++ masks=masks, ++ prompt_length=prompt_length, + cache=cache) + logp = self.llm_decoder(y_pred[:, -1]).log_softmax(dim=-1) + top_ids = self.sampling_ids(logp.squeeze(dim=0), out_tokens, sampling, ignore_eos=True if i < min_len else False).item() +@@ -331,7 +349,7 @@ class Qwen2LM(TransformerLM): + # in stream mode, yield token one by one + yield top_ids + out_tokens.append(top_ids) +- lm_input = self.speech_embedding.weight[top_ids].reshape(1, 1, -1) ++ lm_input = self.speech_embedding.weight[top_ids].reshape(1, 1, -1).detach().clone() + + @torch.inference_mode() + def inference_bistream( +@@ -432,3 +450,141 @@ class Qwen2LM(TransformerLM): + # in stream mode, yield token one by one + yield top_ids + lm_input = self.speech_embedding.weight[top_ids].reshape(1, 1, -1) ++ ++ @torch.inference_mode() ++ def inference_bistream_streaming_input( ++ self, ++ text: torch.Tensor, ++ char_idx: torch.Tensor, ++ prompt_text: torch.Tensor, ++ prompt_text_len: torch.Tensor, ++ prompt_speech_token: torch.Tensor, ++ prompt_speech_token_len: torch.Tensor, ++ embedding: torch.Tensor, ++ uuid: str, ++ input_end: bool, ++ sampling: int = 25, ++ max_token_text_ratio: float = 20, ++ min_token_text_ratio: float = 2, ++ ) -> Generator[torch.Tensor, None, None]: ++ ++ def build_causal_mask(query_len, key_len, devices, dtype): ++ assert key_len >= query_len ++ causal_mask = torch.triu(torch.ones((1, 1, query_len, key_len), device=devices), diagonal=(key_len - query_len) + 1).to(dtype) * -10000.0 ++ return causal_mask ++ ++ device = prompt_text.device ++ ++ if uuid not in self.cache_dict: ++ sos_eos_emb = self.llm_embedding.weight[self.sos_eos].reshape(1, 1, -1) ++ if prompt_speech_token_len != 0: ++ self.prompt_speech_token_emb_dict[uuid] = self.speech_embedding(prompt_speech_token) ++ else: ++ self.prompt_speech_token_emb_dict[uuid] = torch.zeros(1, 0, self.llm_input_size, dtype=prompt_text.dtype).to(device) ++ ++ self.lm_input_dict[uuid] = torch.concat([sos_eos_emb], dim=1) # [1,1,896] ++ ++ self.out_tokens_dict[uuid] = [] ++ self.cache_dict[uuid] = None ++ ++ self.text_cache_dict[uuid] = self.llm.model.model.embed_tokens(prompt_text) # [1, prompt_text, 896] ++ self.next_fill_index[uuid] = -1 ++ self.prompt_length[uuid] = 0 ++ ++ text_emb = self.llm.model.model.embed_tokens(text) ++ ++ for i in range(text_emb.size(1)): ++ self.text_cache_dict[uuid] = torch.concat([self.text_cache_dict[uuid], text_emb[:, i].unsqueeze(1)], dim=1) ++ index = 0 ++ while self.prompt_speech_token_emb_dict[uuid].size(1) != 0: ++ if self.text_cache_dict[uuid].size(1) >= self.mix_ratio[0]: ++ lm_input_text, lm_input_speech = self.text_cache_dict[uuid][:, :self.mix_ratio[0]], self.prompt_speech_token_emb_dict[uuid][:, :self.mix_ratio[1]] ++ index += 1 ++ logging.info('append {} text token {} speech token'.format(lm_input_text.size(1), lm_input_speech.size(1))) ++ self.lm_input_dict[uuid] = torch.concat([self.lm_input_dict[uuid], lm_input_text, lm_input_speech], dim=1) ++ self.text_cache_dict[uuid], self.prompt_speech_token_emb_dict[uuid] = 
self.text_cache_dict[uuid][:, self.mix_ratio[0]:], self.prompt_speech_token_emb_dict[uuid][:, self.mix_ratio[1]:] ++ else: ++ break ++ ++ if self.prompt_speech_token_emb_dict[uuid].size(1) == 0: # 文本token数量多于音频token,混合完以后,剩余文本token,开始解码 ++ # 若上一次解码的 token 是 fill_token,说明 LLM 想要更多 text token ++ # 或者首次预测时,还没开始解码,out_tokens_dict 为空 ++ if ((len(self.out_tokens_dict[uuid]) != 0 and self.out_tokens_dict[uuid][-1] == self.speech_token_size + 2) ++ or (len(self.out_tokens_dict[uuid]) == 0 and self.lm_input_dict[uuid].size(1) == 1)): ++ # token数量够了 ++ if self.text_cache_dict[uuid].size(1) >= self.mix_ratio[0]: ++ lm_input_text = self.text_cache_dict[uuid][:, :self.mix_ratio[0]] # 抽出5个token ++ if len(self.out_tokens_dict[uuid]) != 0 and self.out_tokens_dict[uuid][-1] == self.speech_token_size + 2: # 预测出filling token,前面cache已经缓存,当前直接输入即可 ++ self.lm_input_dict[uuid] = lm_input_text ++ else: # sft刚开始预测,需要和sos token拼接在一起 ++ self.lm_input_dict[uuid] = torch.concat([self.lm_input_dict[uuid], lm_input_text], dim=1) ++ self.text_cache_dict[uuid] = self.text_cache_dict[uuid][:, self.mix_ratio[0]:] ++ else: ++ continue ++ ++ while True: ++ self.prompt_length[uuid] += self.lm_input_dict[uuid].shape[1] ++ seq_len = self.prompt_length[uuid] ++ if self.lm_input_dict[uuid].shape[1] > 1: ++ masks = build_causal_mask(self.lm_input_dict[uuid].shape[1], seq_len, ++ self.lm_input_dict[uuid].device, self.lm_input_dict[uuid].dtype) ++ else: ++ masks = None ++ y_pred, self.cache_dict[uuid] = self.llm.forward_one_step(self.lm_input_dict[uuid], ++ masks=masks, ++ prompt_length=seq_len, ++ cache=self.cache_dict[uuid]) ++ logp = self.llm_decoder(y_pred[:, -1]).log_softmax(dim=-1) ++ # 判断是否生成 filling_token: ++ if self.next_fill_index[uuid] != -1 and len(self.out_tokens_dict[uuid]) == self.next_fill_index[uuid]: ++ top_ids = self.speech_token_size + 2 # 该预测filling token了 ++ self.next_fill_index[uuid] += (self.mix_ratio[1] + 1) # 找到下一个filling token的位置 ++ else: ++ top_ids = self.sampling_ids(logp.squeeze(dim=0), self.out_tokens_dict[uuid], sampling, ignore_eos=True).item() ++ # 特殊 token 处理, fill_token → 中断预测、等待新文本 token。 ++ if top_ids == self.speech_token_size + 2: ++ self.next_fill_index[uuid] = len(self.out_tokens_dict[uuid]) + self.mix_ratio[1] + 1 # -1 > 30 ++ self.out_tokens_dict[uuid].append(top_ids) ++ if top_ids >= self.speech_token_size: ++ if top_ids == self.speech_token_size + 2: # 预测到了filling token, break掉迎接新的文本token ++ break ++ else: ++ raise ValueError('should not get token {}'.format(top_ids)) ++ yield top_ids ++ self.lm_input_dict[uuid] = self.speech_embedding.weight[top_ids].reshape(1, 1, -1).detach().clone() ++ ++ if input_end: ++ # 3. 
final decode 文本全部送完,进行最后的解码。 ++ task_id_emb = self.llm_embedding.weight[self.task_id].reshape(1, 1, -1) ++ self.lm_input_dict[uuid] = torch.concat([self.lm_input_dict[uuid], self.text_cache_dict[uuid], task_id_emb, self.prompt_speech_token_emb_dict[uuid]], dim=1) ++ logging.info('no more text token, decode until met eos') ++ while True: ++ self.prompt_length[uuid] += self.lm_input_dict[uuid].shape[1] ++ seq_len = self.prompt_length[uuid] ++ if self.lm_input_dict[uuid].shape[1] > 1: ++ masks = build_causal_mask(self.lm_input_dict[uuid].shape[1], seq_len, self.lm_input_dict[uuid].device, self.lm_input_dict[uuid].dtype) ++ else: ++ masks = None ++ y_pred, self.cache_dict[uuid] = self.llm.forward_one_step(self.lm_input_dict[uuid], ++ masks=masks, ++ prompt_length=seq_len, ++ cache=self.cache_dict[uuid]) ++ logp = self.llm_decoder(y_pred[:, -1]).log_softmax(dim=-1) ++ top_ids = self.sampling_ids(logp.squeeze(dim=0), self.out_tokens_dict[uuid], sampling, ignore_eos=False).item() ++ self.out_tokens_dict[uuid].append(top_ids) ++ if top_ids >= self.speech_token_size: ++ if top_ids == self.speech_token_size: ++ break ++ else: ++ raise ValueError('should not get token {}'.format(top_ids)) ++ # in stream mode, yield token one by one ++ yield top_ids ++ self.lm_input_dict[uuid] = self.speech_embedding.weight[top_ids].reshape(1, 1, -1).detach().clone() ++ ++ # this user is done ++ self.prompt_speech_token_emb_dict.pop(uuid) ++ self.lm_input_dict.pop(uuid) ++ self.out_tokens_dict.pop(uuid) ++ self.cache_dict.pop(uuid) ++ self.text_cache_dict.pop(uuid) ++ self.next_fill_index.pop(uuid) +\ No newline at end of file +diff --git a/cosyvoice/utils/common.py b/cosyvoice/utils/common.py +index 3e61a8c..d316b92 100644 +--- a/cosyvoice/utils/common.py ++++ b/cosyvoice/utils/common.py +@@ -107,12 +107,33 @@ def init_weights(m, mean=0.0, std=0.01): + + # Repetition Aware Sampling in VALL-E 2 + def ras_sampling(weighted_scores, decoded_tokens, sampling, top_p=0.8, top_k=25, win_size=10, tau_r=0.1): +- top_ids = nucleus_sampling(weighted_scores, top_p=top_p, top_k=top_k) ++ top_ids = dst_sampling(weighted_scores, top_p=top_p, top_k=top_k) + rep_num = (torch.tensor(decoded_tokens[-win_size:]).to(weighted_scores.device) == top_ids).sum().item() + if rep_num >= win_size * tau_r: + top_ids = random_sampling(weighted_scores, decoded_tokens, sampling) + return top_ids + ++def dst_sampling(weighted_scores, top_p=0.8, top_k=25): ++ ++ sorted_value, sorted_idx = weighted_scores.softmax(dim=0).sort(descending=True, stable=True) ++ ++ cum_sum = torch.cumsum(sorted_value, dim=0) ++ n = sorted_value.size(0) ++ device = cum_sum.device ++ pre_cum_sum = torch.cat([torch.zeros(1, device=device), cum_sum[:-1]]) ++ ++ indices = torch.arange(n ,device=device) ++ condition = (pre_cum_sum < top_p) & (indices < top_k) ++ ++ max_i_tensor = torch.where(condition, indices, torch.tensor(-1, device=device)) ++ n_selected = max_i_tensor.max() + 1 ++ ++ selected_prob = sorted_value[:n_selected] ++ selected_indices = sorted_idx[:n_selected] ++ ++ top_ids = selected_indices[selected_prob.multinomial(1, replacement=True)] ++ ++ return top_ids + + def nucleus_sampling(weighted_scores, top_p=0.8, top_k=25): + prob, indices = [], [] diff --git a/ACL_PyTorch/built-in/audio/CosyVoice/CosyVoice2/diff_CosyVoice.patch b/ACL_PyTorch/built-in/audio/CosyVoice/CosyVoice2/800I/diff_CosyVoice_800I.patch old mode 100755 new mode 100644 similarity index 100% rename from ACL_PyTorch/built-in/audio/CosyVoice/CosyVoice2/diff_CosyVoice.patch rename to 
ACL_PyTorch/built-in/audio/CosyVoice/CosyVoice2/800I/diff_CosyVoice_800I.patch diff --git a/ACL_PyTorch/built-in/audio/CosyVoice/CosyVoice2/modeling_qwen2.py b/ACL_PyTorch/built-in/audio/CosyVoice/CosyVoice2/800I/modeling_qwen2.py old mode 100755 new mode 100644 similarity index 100% rename from ACL_PyTorch/built-in/audio/CosyVoice/CosyVoice2/modeling_qwen2.py rename to ACL_PyTorch/built-in/audio/CosyVoice/CosyVoice2/800I/modeling_qwen2.py diff --git a/ACL_PyTorch/built-in/audio/CosyVoice/CosyVoice2/README.md b/ACL_PyTorch/built-in/audio/CosyVoice/CosyVoice2/README.md index 661d8377a39935d96a1e5d01e9ef8f6c4d9caa31..b418438adf65815cf58a74e63b988a7a321d1e87 100755 --- a/ACL_PyTorch/built-in/audio/CosyVoice/CosyVoice2/README.md +++ b/ACL_PyTorch/built-in/audio/CosyVoice/CosyVoice2/README.md @@ -55,28 +55,36 @@ cd ModelZoo-PyTorch/ACL_PyTorch/built-in/audio/CosyVoice/CosyVoice2 cd CosyVoice git reset --hard fd45708 git submodule update --init --recursive - git apply ../diff_CosyVoice.patch + git apply ../${platform}/diff_CosyVoice_${platform}.patch + # 将infer.py复制到CosyVoice中 + cp ../infer.py ./ # 获取Transformer源码 + cd .. git clone https://github.com/huggingface/transformers.git cd transformers git checkout v4.37.0 cd .. # 将modeling_qwen模型文件替换到transformers仓内 - mv ../modeling_qwen2.py ./transformers/src/transformers/models/qwen2 + mv ../${platform}/modeling_qwen2.py ./transformers/src/transformers/models/qwen2 ``` 文件目录结构大致如下: ```text 📁 CosyVoice/ ├── 📁 CosyVoice2/ - | |── 📄 diff_CosyVoice.patch - | |── 📄 modeling_qwen2.py + | |── 📁 300I + | |── 📄 diff_CosyVoice_300I.patch + | |── 📄 modeling_qwen2.py + | |── 📁 800I + | |── 📄 diff_CosyVoice_800I.patch + | |── 📄 modeling_qwen2.py | |── 📁 CosyVoice | |── 📁 cosyVoice源码文件 # cosyVoice的源码文件,此处不一一列举 - │ ├── 📁 CosyVoice-0.5B/ # 权重文件 - │ ├── 📁 transformers/ # transformers文件,里面有修改过的modeling_qwen2.py文件 - │ ├── 📄 infer.py # 推理脚本 - │ └── 📄 modify_onnx.py # 模型转换脚本 + │ ├── 📁 CosyVoice-0.5B/ # 权重文件 + │ ├── 📁 transformers/ # transformers库,里面修改modeling_qwen2.py文件 + │── 📄 requirements.txt # 依赖库 + |── 📄 infer.py # 推理脚本 + └── 📄 modify_onnx.py # 模型转换脚本 ``` 2. 安装依赖 @@ -103,7 +111,7 @@ cd ModelZoo-PyTorch/ACL_PyTorch/built-in/audio/CosyVoice/CosyVoice2 3. 安装msit工具 - 参考[msit](https://gitee.com/ascend/msit)安装工具中的benchmark和surgen组件。(未安装会提示 ais_bench 导入失败报错) + 参考[msit](https://gitee.com/ascend/msit)安装工具中的benchmark和surgeon组件。(未安装会提示 ais_bench 导入失败报错) 4. 
获取权重数据 @@ -153,7 +161,7 @@ cd ModelZoo-PyTorch/ACL_PyTorch/built-in/audio/CosyVoice/CosyVoice2 执行ATC命令,将利用npu-smi info命令获取的芯片型号填入${soc_version}中 ``` - atc --framework=5 --soc_version=${soc_version} --model ./${CosyVoice2-0.5B}/speech_token_md.onnx --output ./${CosyVoice2-0.5B}/speech --input_shape="feats:1,128,-1;feats_length:1" + atc --framework=5 --soc_version=${soc_version} --model ./${CosyVoice2-0.5B}/speech_token_md.onnx --output ./${CosyVoice2-0.5B}/speech --input_shape="feats:1,128,-1;feats_length:1" --precision_mode allow_fp32_to_fp16 atc --framework=5 --soc_version=${soc_version} --model ./${CosyVoice2-0.5B}/flow.decoder.estimator.fp32.onnx --output ./${CosyVoice2-0.5B}/flow --input_shape="x:2,80,-1;mask:2,1,-1;mu:2,80,-1;t:2;spks:2,80;cond:2,80,-1" atc --framework=5 --soc_version=${soc_version} --model ./${CosyVoice2-0.5B}/flow.decoder.estimator.fp32.onnx --output ./${CosyVoice2-0.5B}/flow_static --input_shape="x:2,80,-1;mask:2,1,-1;mu:2,80,-1;t:2;spks:2,80;cond:2,80,-1" --dynamic_dims="100,100,100,100;200,200,200,200;300,300,300,300;400,400,400,400;500,500,500,500;600,600,600,600;700,700,700,700" --input_format=ND ``` @@ -163,10 +171,7 @@ cd ModelZoo-PyTorch/ACL_PyTorch/built-in/audio/CosyVoice/CosyVoice2 ### 2 开始推理验证 - 1. 首先移动infer.py文件到CosyVoice目录下 - - - 2. 设置环境变量,执行推理命令 + 1. 设置环境变量,执行推理命令 ``` # 1. 指定使用NPU ID,默认为0 @@ -187,9 +192,12 @@ cd ModelZoo-PyTorch/ACL_PyTorch/built-in/audio/CosyVoice/CosyVoice2 * 非流式输入:将推理结果保存在`sft_i.wav`中,并打屏性能数据:实时率(rtf),指的是平均1s时长的音频需要多少时间处理。 * 流式输入:将推理结果保存在`stream_input_out_i.wav`文件中,并打屏性能数据:实时率(rtf) + 3. 如因为意外操作导致torchair编译失败,需将已生成的.torchair_cache路径删除,避免使用编译错误的图导致的前向出错。 + ### 3 性能数据 | 模型 |芯片|rtf(实时率)| |-----------|------|------| - | cosyvoice |800I A2|0.28s| + | cosyvoice |800I A2|0.28| + | cosyvoice |300I DUO|0.75| diff --git a/ACL_PyTorch/built-in/audio/CosyVoice/CosyVoice2/infer.py b/ACL_PyTorch/built-in/audio/CosyVoice/CosyVoice2/infer.py old mode 100755 new mode 100644 index 972cb18522597bc8ff27d593e31658820325815a..3f17564040ae2805b9835916e9b541ecd7f5516e --- a/ACL_PyTorch/built-in/audio/CosyVoice/CosyVoice2/infer.py +++ b/ACL_PyTorch/built-in/audio/CosyVoice/CosyVoice2/infer.py @@ -9,7 +9,7 @@ # See the Mulan PSL v2 for more details. 
 import argparse
-from tqdm import tqdm
+import time
 import torch
 import torchaudio
 import torch_npu
@@ -17,7 +17,6 @@ from torch_npu.contrib import transfer_to_npu
 import torchair as tng
 from torchair.configs.compiler_config import CompilerConfig
 from cosyvoice.cli.cosyvoice import CosyVoice2
-from cosyvoice.utils.file_utils import load_wav


 def no_stream_input_inference(args, cosyvoice, prompt_txt):
@@ -27,9 +26,19 @@ def no_stream_input_inference(args, cosyvoice, prompt_txt):
     for _ in enumerate(cosyvoice.inference_sft(prompt_txt[0], '中文女', stream=args.stream_out)):
         pass
     print('warm up end')
-    for _ in range(args.infer_count):
-        for i, j in enumerate(cosyvoice.inference_sft(prompt_txt[0], '中文女', stream=args.stream_out)):
-            torchaudio.save('sft_{}.wav'.format(i), j['tts_speech'], cosyvoice.sample_rate)
+    infer_res = [torch.tensor([]) for _ in range(args.infer_count)]
+    rtf = []
+    for i_step in range(args.infer_count):
+        start_time = time.time()
+        for _, j in enumerate(cosyvoice.inference_sft(prompt_txt[0], '中文女', stream=args.stream_out)):
+            infer_res[i_step] = torch.cat((infer_res[i_step], j['tts_speech']), dim=1)
+        end_time = time.time()
+        speech_len = infer_res[i_step].shape[1] / cosyvoice.sample_rate
+        print(f"single infer RTF: {(end_time - start_time) / speech_len}")
+        rtf.append((end_time - start_time) / speech_len)
+        print(f"save out wav file to sft_out_{i_step+1}.wav")
+        torchaudio.save(f"sft_out_{i_step+1}.wav", infer_res[i_step], cosyvoice.sample_rate)
+    print(f"avg RTF: {sum(rtf) / len(rtf)}")


 def stream_input_inference(args, cosyvoice, prompt_txt):
@@ -44,13 +53,13 @@ def stream_input_inference(args, cosyvoice, prompt_txt):
                 if mode == "warmup":
                     pass
                 else:
-                    infer_res[i_step] = torch.cat((infer_res[i_step], j['tts_speech']), dim=1)
+                    infer_res[step] = torch.cat((infer_res[step], j['tts_speech']), dim=1)
         else:
             for _, j in enumerate(cosyvoice.inference_sft_streaming_input(char, char_idx, "中文女", user_id="AscendUser", input_end=False, stream=args.stream_out)):
                 if mode == "warmup":
                     pass
                 else:
-                    infer_res[i_step] = torch.cat((infer_res[i_step], j['tts_speech']), dim=1)
+                    infer_res[step] = torch.cat((infer_res[step], j['tts_speech']), dim=1)


     infer_res = [torch.tensor([]) for _ in range(args.infer_count)]
@@ -61,13 +70,19 @@ def stream_input_inference(args, cosyvoice, prompt_txt):
     print("warm up end")

     print("inference start")
+    rtf = []
     for i_step in range(args.infer_count):
+        start_time = time.time()
         inference_step(i_step, mode="inference")
+        end_time = time.time()
+        speech_len = infer_res[i_step].shape[1] / cosyvoice.sample_rate
+        print(f"single infer RTF: {(end_time - start_time) / speech_len}")
+        rtf.append((end_time - start_time) / speech_len)
+        print(f"save out wav file to stream_input_out_{i_step+1}.wav")
+        torchaudio.save(f"stream_input_out_{i_step+1}.wav", infer_res[i_step], cosyvoice.sample_rate)
+    print(f"avg RTF: {sum(rtf) / len(rtf)}")
     print("inference end")
-    print(f"save out wav file ...")
-    for i_step in tqdm(range(args.infer_count)):
-        torchaudio.save(f"stream_input_out_{i_step+1}.wav", infer_res[i_step], 24000)


 if __name__ == '__main__':
     torch_npu.npu.set_compile_mode(jit_compile=False)
diff --git a/ACL_PyTorch/built-in/audio/CosyVoice/CosyVoice2/requirements.txt b/ACL_PyTorch/built-in/audio/CosyVoice/CosyVoice2/requirements.txt
index c3cdf647ef6cb4a52b6e650321c2000d4bbfe0ed..7afcddafb269269a08c4b19c0f1ef3059c09bd22 100755
--- a/ACL_PyTorch/built-in/audio/CosyVoice/CosyVoice2/requirements.txt
+++ b/ACL_PyTorch/built-in/audio/CosyVoice/CosyVoice2/requirements.txt
@@ -20,6 +20,7 @@ onnxruntime==1.16.0
 openai-whisper==20231117
 protobuf==4.25
 pydantic==2.7.0
+pyworld==0.3.4
 rich==13.7.1
 soundfile==0.12.1
 tensorboard==2.14.0
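
For context, the sketch below shows how the entry points added by this patch set fit together on an NPU device: `CosyVoice2(..., load_om=True)` attaches the ATC-converted `.om` sessions via ais_bench, and `inference_sft_streaming_input` consumes text incrementally, with `input_end=True` on the final chunk and `stream=True` required by the streaming-output assertion. This is a minimal sketch, not the shipped driver (`infer.py` remains the reference): the weight directory, the sample sentence, and the one-character-per-call feeding loop are illustrative assumptions.

```python
# Minimal sketch of the streaming-input API added by this patch (assumptions noted inline).
# infer.py is the authoritative driver; the model path and chunking policy here are illustrative.
import torch
import torchaudio
import torch_npu                                 # Ascend NPU backend
from torch_npu.contrib import transfer_to_npu    # routes torch.cuda.* calls to NPU, as in infer.py

from cosyvoice.cli.cosyvoice import CosyVoice2

MODEL_DIR = './CosyVoice2-0.5B'                  # assumption: weights fetched as described in the README

# load_om=True (new in this patch) attaches the ATC-converted speech/flow .om sessions.
cosyvoice = CosyVoice2(MODEL_DIR, load_jit=False, load_trt=False, fp16=False, load_om=True)

text = '收到好友从远方寄来的生日礼物。'             # sample sentence, illustrative only
chunks = []

# Assumption: feed one character per call; infer.py owns the real chunking policy.
for char_idx, char in enumerate(text):
    last = (char_idx == len(text) - 1)
    # stream=True is mandatory: tts_streaming_input asserts "output must be streaming".
    for out in cosyvoice.inference_sft_streaming_input(char, char_idx, '中文女',
                                                       user_id='AscendUser',
                                                       input_end=last, stream=True):
        chunks.append(out['tts_speech'])         # each yield is a [1, T] waveform chunk

torchaudio.save('stream_input_demo.wav', torch.cat(chunks, dim=1), cosyvoice.sample_rate)
```

Each yielded chunk can equally be handed to a playback sink as soon as it arrives instead of being concatenated at the end, which is the point of the streaming-input path.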