diff --git a/ACL_PyTorch/built-in/audio/CosyVoice/CosyVoice2/README.md b/ACL_PyTorch/built-in/audio/CosyVoice/CosyVoice2/README.md
new file mode 100755
index 0000000000000000000000000000000000000000..2bc13c07d83d50a74e021fd5e7b08fa29308ec20
--- /dev/null
+++ b/ACL_PyTorch/built-in/audio/CosyVoice/CosyVoice2/README.md
@@ -0,0 +1,155 @@
+# CosyVoice-推理指导
+
+- [CosyVoice-推理指导](#cosyvoice-推理指导)
+- [概述](#概述)
+- [推理环境准备](#推理环境准备)
+- [快速上手](#快速上手)
+  - [获取源码](#获取源码)
+  - [模型推理](#模型推理)
+    - [1 模型转换](#1-模型转换)
+    - [2 开始推理验证](#2-开始推理验证)
+    - [3 性能数据](#3-性能数据)
+
+******
+
+# 概述
+  CosyVoice是一款基于语音量化编码的语音生成大模型,能够深度融合文本理解和语音生成,实现自然流畅的语音体验。它通过离散化编码并依托大模型技术,能够精准解析并诠释各类文本内容,将其转化为宛如真人般的自然语音。CosyVoice2在CosyVoice的基础上,将Qwen2模型接入LLM部分,实现了推理加速。
+
+- 版本说明:
+  ```
+  url=https://github.com/FunAudioLLM/CosyVoice
+  commit_id=fd45708
+  model_name=CosyVoice2
+  ```
+
+# 推理环境准备
+- 该模型需要以下插件与驱动
+
+  **表 1** 版本配套表
+
+  | 配套 | 版本 | 环境准备指导 |
+  | ------ | ------ | ------ |
+  | 固件与驱动 | 24.0.RC3 | [Pytorch框架推理环境准备](https://www.hiascend.com/document/detail/zh/ModelZoo/pytorchframework/pies) |
+  | CANN | 8.0.RC3 | 包含kernels包和toolkit包 |
+  | Python | 3.8 | - |
+  | PyTorch | 2.4.0 | - |
+  | Ascend Extension PyTorch | 2.4.0.post2 | - |
+  | 说明:Atlas 800I A2 推理卡和Atlas 300I DUO 推理卡请按CANN版本选择实际的固件与驱动版本。 | \ | \ |
+
+
+# 快速上手
+
+## 获取源码
+
+1. 获取`PyTorch`源码
+   ```
+   # 获取CosyVoice源码
+   git clone https://github.com/FunAudioLLM/CosyVoice
+   cd CosyVoice
+   git reset --hard fd45708
+   git submodule update --init --recursive
+   git apply ../diff_CosyVoice.patch
+   # 获取Transformers源码
+   git clone https://github.com/huggingface/transformers.git
+   cd transformers
+   git checkout v4.37.0
+   cd ..
+   # 将modeling_qwen2.py模型文件替换到transformers仓内
+   mv ../modeling_qwen2.py ./transformers/src/transformers/models/qwen2
+   ```
+
+2. 安装依赖
+   ```
+   pip3 install -r ../requirements.txt
+   apt-get install sox  # CentOS系统请使用:yum install sox
+   ```
+   注:如果遇到无法安装WeTextProcessing的场景,可以参考以下方法手动编译安装
+   ```bash
+   # 下载安装包并解压
+   wget https://www.openfst.org/twiki/pub/FST/FstDownload/openfst-1.8.3.tar.gz
+   tar -zxvf openfst-1.8.3.tar.gz
+   # 进入目录后编译安装
+   cd openfst-1.8.3
+   ./configure --enable-far --enable-mpdt --enable-pdt
+   make install
+   cd ..
+   pip3 install WeTextProcessing==1.0.4.1
+   ```
+
+3. 安装msit工具
+
+   参考[msit](https://gitee.com/ascend/msit)安装工具中的benchmark和surgeon组件。
+
+
+4. 获取权重数据
+
+   本案例以CosyVoice2-0.5B为例,其他权重请自行适配。
+
+   获取 https://www.modelscope.cn/iic/CosyVoice2-0.5B 权重文件夹,放在CosyVoice目录下。
+
+   或者通过git方式获取
+   ```
+   # git模型下载,请确保已安装git lfs
+   git clone https://www.modelscope.cn/iic/CosyVoice2-0.5B.git CosyVoice/CosyVoice2-0.5B
+   ```
+   本用例采用sft预训练音色推理,请额外下载spk2info.pt权重并放到权重目录下
+   ```
+   wget https://www.modelscope.cn/models/iic/CosyVoice-300M-SFT/resolve/master/spk2info.pt
+   ```
+
+## 模型推理
+
+### 1 模型转换
+
+   模型权重中提供了flow.decoder.estimator.fp32.onnx和speech_tokenizer_v2.onnx两个onnx模型,对其进行结构修改后,使用`ATC`工具将`.onnx`文件转为离线推理模型`.om`文件。
+
+1. 修改onnx模型结构
+
+   ```
+   python3 modify_onnx.py ${CosyVoice2-0.5B}
+   ```
+
+   ${CosyVoice2-0.5B}为onnx模型所在的权重文件夹,使用其他权重时请自行更改文件夹名称。执行该命令后,会在CosyVoice2-0.5B目录下生成修改后的onnx文件speech_token_md.onnx。
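+
+   修改完成后,可以用下面的脚本快速确认ReduceMax节点的keepdims属性已被置为1(可选的校验脚本,仅作示意;脚本中假设权重目录为CosyVoice2-0.5B,请按实际路径调整):
+
+   ```python
+   import onnx
+
+   model = onnx.load("CosyVoice2-0.5B/speech_token_md.onnx")
+   for node in model.graph.node:
+       if node.op_type == "ReduceMax":
+           keepdims = next((attr.i for attr in node.attribute if attr.name == "keepdims"), None)
+           print(node.name, "keepdims =", keepdims)  # 修改成功时应均为1
+   ```
+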
+2. 使用`ATC`工具将`ONNX`模型转为`OM`模型
+
+   配置环境变量
+
+   ```
+   source /usr/local/Ascend/ascend-toolkit/set_env.sh
+   ```
+
+   执行ATC命令,将通过`npu-smi info`命令查询到的芯片型号填入`${soc_version}`中
+
+   ```
+   atc --framework=5 --soc_version=${soc_version} --model ./${CosyVoice2-0.5B}/speech_token_md.onnx --output ./${CosyVoice2-0.5B}/speech --input_shape="feats:1,128,-1;feats_length:1"
+   atc --framework=5 --soc_version=${soc_version} --model ./${CosyVoice2-0.5B}/flow.decoder.estimator.fp32.onnx --output ./${CosyVoice2-0.5B}/flow --input_shape="x:2,80,-1;mask:2,1,-1;mu:2,80,-1;t:2;spks:2,80;cond:2,80,-1"
+   atc --framework=5 --soc_version=${soc_version} --model ./${CosyVoice2-0.5B}/flow.decoder.estimator.fp32.onnx --output ./${CosyVoice2-0.5B}/flow_static --input_shape="x:2,80,-1;mask:2,1,-1;mu:2,80,-1;t:2;spks:2,80;cond:2,80,-1" --dynamic_dims="100,100,100,100;200,200,200,200;300,300,300,300;400,400,400,400;500,500,500,500;600,600,600,600;700,700,700,700"
+   ```
+
+   执行完成后,在权重目录CosyVoice2-0.5B下会生成三个om模型,分别为speech_{system}_{arch}.om、flow_{system}_{arch}.om和flow_static.om。其中flow_static.om为分档模型,在流式输出中生效;档位设置为模型默认的流式输出token档位,如果在模型中修改了token_hop_len,档位也需要对应修改。
+
+   注:om模型文件名中的{system}_{arch}后缀由当前操作系统和CPU架构决定(如linux_aarch64),与补丁中platform.system()、platform.machine()的取值保持一致。
+
+### 2 开始推理验证
+
+1. 首先移动infer.py文件到CosyVoice目录下
+
+
+2. 设置环境变量,执行推理命令
+
+   ```
+   # 指定使用的NPU ID,默认为0
+   export ASCEND_RT_VISIBLE_DEVICES=0
+   export PYTHONPATH=third_party/Matcha-TTS:$PYTHONPATH
+   export PYTHONPATH=transformers/src:$PYTHONPATH
+   python3 infer.py --model_path=${CosyVoice2-0.5B} --stream
+   ```
+   - --model_path:权重路径
+   - --warm_up_times:warm up次数,默认为2
+   - --infer_count:循环推理次数,默认为20
+   - --stream:是否执行流式推理
+
+   推理开始后会默认先执行warm_up,用于完成首次编译。首次编译耗时较长,编译结束后会在当前目录下生成.torchair_cache缓存文件,后续推理无需重复编译。warm_up结束后执行正式推理,推理结果保存在sft_{i}.wav中(i为输出音频的分段序号),并打屏性能数据:实时率(rtf),即平均处理1s时长音频所需的时间。
+
+### 3 性能数据
+
+|模型|芯片|rtf(实时率)|
+|------|------|------|
+|cosyvoice|800I A2|0.28s|
+
diff --git a/ACL_PyTorch/built-in/audio/CosyVoice/CosyVoice2/diff_CosyVoice.patch b/ACL_PyTorch/built-in/audio/CosyVoice/CosyVoice2/diff_CosyVoice.patch
new file mode 100755
index 0000000000000000000000000000000000000000..d9f3153693464bfe445573ed0a195fe9e8240739
--- /dev/null
+++ b/ACL_PyTorch/built-in/audio/CosyVoice/CosyVoice2/diff_CosyVoice.patch
@@ -0,0 +1,375 @@
+diff --git a/cosyvoice/cli/cosyvoice.py b/cosyvoice/cli/cosyvoice.py
+index e2d62e2..99f8463 100644
+--- a/cosyvoice/cli/cosyvoice.py
++++ b/cosyvoice/cli/cosyvoice.py
+@@ -13,11 +13,13 @@
+ # limitations under the License.
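
该补丁为CosyVoice2增加load_om开关,通过ais_bench的InferSession加载前文ATC生成的om模型,用于speech tokenizer与flow estimator的推理。下面给出InferSession调用方式的独立小例子(仅作示意,非补丁内容;om文件名后缀与输入shape均为假设值,请按实际环境调整):

```python
import numpy as np
from ais_bench.infer.interface import InferSession

# 加载ATC生成的动态shape om模型(文件名后缀取决于实际的操作系统与CPU架构,此处为假设值)
speech_om = InferSession(0, "CosyVoice2-0.5B/speech_linux_aarch64.om")

feat = np.random.randn(1, 128, 300).astype(np.float32)     # whisper log-mel特征,末维长度可变
feat_len = np.array([feat.shape[2]], dtype=np.int32)
# 动态shape模型使用dymshape模式推理,custom_sizes为预留的输出内存大小(字节)
speech_token = speech_om.infer([feat, feat_len], mode="dymshape", custom_sizes=[100000000])[0]
print(speech_token.shape)
```
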
+ import os + import time ++import platform + from typing import Generator + from tqdm import tqdm + from hyperpyyaml import load_hyperpyyaml + from modelscope import snapshot_download + import torch ++from ais_bench.infer.interface import InferSession + from cosyvoice.cli.frontend import CosyVoiceFrontEnd + from cosyvoice.cli.model import CosyVoiceModel, CosyVoice2Model + from cosyvoice.utils.file_utils import logging +@@ -126,7 +128,7 @@ class CosyVoice: + + class CosyVoice2(CosyVoice): + +- def __init__(self, model_dir, load_jit=False, load_trt=False, fp16=False): ++ def __init__(self, model_dir, load_jit=False, load_trt=False, fp16=False, load_om=False): + self.instruct = True if '-Instruct' in model_dir else False + self.model_dir = model_dir + self.fp16 = fp16 +@@ -155,6 +157,16 @@ class CosyVoice2(CosyVoice): + self.model.load_trt('{}/flow.decoder.estimator.{}.mygpu.plan'.format(model_dir, 'fp16' if self.fp16 is True else 'fp32'), + '{}/flow.decoder.estimator.fp32.onnx'.format(model_dir), + self.fp16) ++ if load_om: ++ arch = platform.machine() ++ system = platform.system().lower() ++ flow_om = InferSession(0, '{}/flow_{}_{}.om'.format(model_dir, system ,arch)) ++ flow_om_static = InferSession(0, '{}/flow_static.om'.format(model_dir)) ++ speech_om = InferSession(0, '{}/speech_{}_{}.om'.format(model_dir, system ,arch)) ++ self.frontend.speech_om = speech_om ++ self.frontend.flow_om = flow_om ++ self.model.flow.decoder.flow_om_static = flow_om_static ++ self.model.flow.decoder.flow_om = flow_om + del configs + + def inference_instruct(self, *args, **kwargs): +diff --git a/cosyvoice/cli/frontend.py b/cosyvoice/cli/frontend.py +index 6e10f00..0eb45a8 100644 +--- a/cosyvoice/cli/frontend.py ++++ b/cosyvoice/cli/frontend.py +@@ -71,6 +71,8 @@ class CosyVoiceFrontEnd: + self.zh_tn_model = ZhNormalizer(remove_erhua=False, full_to_half=False, overwrite_cache=True) + self.en_tn_model = EnNormalizer() + self.inflect_parser = inflect.engine() ++ self.speech_om = None ++ self.flow_om = None + + def _extract_text_token(self, text): + if isinstance(text, Generator): +@@ -92,11 +94,16 @@ class CosyVoiceFrontEnd: + def _extract_speech_token(self, speech): + assert speech.shape[1] / 16000 <= 30, 'do not support extract speech token for audio longer than 30s' + feat = whisper.log_mel_spectrogram(speech, n_mels=128) +- speech_token = self.speech_tokenizer_session.run(None, +- {self.speech_tokenizer_session.get_inputs()[0].name: +- feat.detach().cpu().numpy(), +- self.speech_tokenizer_session.get_inputs()[1].name: +- np.array([feat.shape[2]], dtype=np.int32)})[0].flatten().tolist() ++ if torch.npu.is_available() and self.speech_om: ++ feed = [feat.detach().cpu().numpy(), np.array([feat.shape[2]], dtype=np.int32)] ++ speech_token = self.speech_om.infer(feed, mode='dymshape', custom_sizes=[100000000])[0].flatten().tolist() ++ self.flow_om.set_context() ++ else: ++ speech_token = self.speech_tokenizer_session.run(None, ++ {self.speech_tokenizer_session.get_inputs()[0].name: ++ feat.detach().cpu().numpy(), ++ self.speech_tokenizer_session.get_inputs()[1].name: ++ np.array([feat.shape[2]], dtype=np.int32)})[0].flatten().tolist() + speech_token = torch.tensor([speech_token], dtype=torch.int32).to(self.device) + speech_token_len = torch.tensor([speech_token.shape[1]], dtype=torch.int32).to(self.device) + return speech_token, speech_token_len +diff --git a/cosyvoice/cli/model.py b/cosyvoice/cli/model.py +index 9ebf8cb..af31dc9 100644 +--- a/cosyvoice/cli/model.py ++++ b/cosyvoice/cli/model.py +@@ -362,13 
+362,18 @@ class CosyVoice2Model(CosyVoiceModel): + with self.lock: + self.tts_speech_token_dict[this_uuid], self.llm_end_dict[this_uuid] = [], False + self.hift_cache_dict[this_uuid] = None +- p = threading.Thread(target=self.llm_job, args=(text, prompt_text, llm_prompt_speech_token, llm_embedding, this_uuid)) +- p.start() + if stream is True: + token_offset = 0 +- while True: +- time.sleep(0.1) +- if len(self.tts_speech_token_dict[this_uuid]) - token_offset >= self.token_hop_len + self.flow.pre_lookahead_len: ++ # 删除线程操作,串行执行推理,加速首包时延 ++ for i in self.llm.inference(text=text.to(self.device), ++ text_len=torch.tensor([text.shape[1]], dtype=torch.int32).to(self.device), ++ prompt_text=prompt_text.to(self.device), ++ prompt_text_len=torch.tensor([prompt_text.shape[1]], dtype=torch.int32).to(self.device), ++ prompt_speech_token=llm_prompt_speech_token.to(self.device), ++ prompt_speech_token_len=torch.tensor([llm_prompt_speech_token.shape[1]], dtype=torch.int32).to(self.device), ++ embedding=llm_embedding.to(self.device)): ++ self.tts_speech_token_dict[this_uuid].append(i) ++ if len(self.tts_speech_token_dict[this_uuid]) - token_offset >= self.token_hop_len + self.flow.pre_lookahead_len: + this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid][:token_offset + self.token_hop_len + self.flow.pre_lookahead_len]).unsqueeze(dim=0) + this_tts_speech = self.token2wav(token=this_tts_speech_token, + prompt_token=flow_prompt_speech_token, +@@ -379,10 +384,6 @@ class CosyVoice2Model(CosyVoiceModel): + finalize=False) + token_offset += self.token_hop_len + yield {'tts_speech': this_tts_speech.cpu()} +- if self.llm_end_dict[this_uuid] is True and len(self.tts_speech_token_dict[this_uuid]) - token_offset < self.token_hop_len + self.flow.pre_lookahead_len: +- break +- p.join() +- # deal with remain tokens, make sure inference remain token len equals token_hop_len when cache_speech is not None + this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid]).unsqueeze(dim=0) + this_tts_speech = self.token2wav(token=this_tts_speech_token, + prompt_token=flow_prompt_speech_token, +@@ -393,6 +394,8 @@ class CosyVoice2Model(CosyVoiceModel): + finalize=True) + yield {'tts_speech': this_tts_speech.cpu()} + else: ++ p = threading.Thread(target=self.llm_job, args=(text, prompt_text, llm_prompt_speech_token, llm_embedding, this_uuid)) ++ p.start() + # deal with all tokens + p.join() + this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid]).unsqueeze(dim=0) +diff --git a/cosyvoice/flow/flow_matching.py b/cosyvoice/flow/flow_matching.py +index 6a60f6d..fbe7545 100644 +--- a/cosyvoice/flow/flow_matching.py ++++ b/cosyvoice/flow/flow_matching.py +@@ -14,6 +14,7 @@ + import threading + import torch + import torch.nn.functional as F ++import numpy as np + from matcha.models.components.flow_matching import BASECFM + + +@@ -32,6 +33,8 @@ class ConditionalCFM(BASECFM): + # Just change the architecture of the estimator here + self.estimator = estimator + self.lock = threading.Lock() ++ self.flow_om = None ++ self.flow_om_static = None + + @torch.inference_mode() + def forward(self, mu, mask, n_timesteps, temperature=1.0, spks=None, cond=None, prompt_len=0, flow_cache=torch.zeros(1, 80, 0, 2)): +@@ -105,12 +108,26 @@ class ConditionalCFM(BASECFM): + t_in[:] = t.unsqueeze(0) + spks_in[0] = spks + cond_in[0] = cond +- dphi_dt = self.forward_estimator( +- x_in, mask_in, +- mu_in, t_in, +- spks_in, +- cond_in +- ) ++ # 动态分档推理, 在流式输出中,每次输出的token数目固定,可以采取动态分档模型执行推理 ++ if 
torch.npu.is_available() and self.flow_om_static and x.size(2)%100==0 and x.size(2)<800: ++ feed_list = [x_in, mask_in, mu_in, t_in, spks_in, cond_in] ++ feed = [i.cpu().detach().numpy().astype(np.float32) for i in feed_list] ++ dphi_dt = self.flow_om_static.infer(feed, mode="dymdims") ++ self.flow_om.set_context() ++ dphi_dt = torch.from_numpy(dphi_dt[0]).npu() ++ # 输出的token数目不固定场景采用动态模型推理 ++ elif torch.npu.is_available() and self.flow_om: ++ feed_list = [x_in, mask_in, mu_in, t_in, spks_in, cond_in] ++ feed = [i.cpu().detach().numpy().astype(np.float32) for i in feed_list] ++ dphi_dt = self.flow_om.infer(feed, mode="dymshape", custom_sizes=10000000) ++ dphi_dt = torch.from_numpy(dphi_dt[0]).npu() ++ else: ++ dphi_dt = self.forward_estimator( ++ x_in, mask_in, ++ mu_in, t_in, ++ spks_in, ++ cond_in ++ ) + dphi_dt, cfg_dphi_dt = torch.split(dphi_dt, [x.size(0), x.size(0)], dim=0) + dphi_dt = ((1.0 + self.inference_cfg_rate) * dphi_dt - self.inference_cfg_rate * cfg_dphi_dt) + x = x + dt * dphi_dt +diff --git a/cosyvoice/hifigan/generator.py b/cosyvoice/hifigan/generator.py +index c47bf05..a7c8b37 100644 +--- a/cosyvoice/hifigan/generator.py ++++ b/cosyvoice/hifigan/generator.py +@@ -23,6 +23,7 @@ import torch.nn.functional as F + from torch.nn import Conv1d + from torch.nn import ConvTranspose1d + from torch.nn.utils import remove_weight_norm ++from torch.nn.utils.parametrize import remove_parametrizations + from torch.nn.utils.parametrizations import weight_norm + from torch.distributions.uniform import Uniform + +@@ -99,8 +100,8 @@ class ResBlock(torch.nn.Module): + + def remove_weight_norm(self): + for idx in range(len(self.convs1)): +- remove_weight_norm(self.convs1[idx]) +- remove_weight_norm(self.convs2[idx]) ++ remove_parametrizations(self.convs1[idx], "weight") ++ remove_parametrizations(self.convs2[idx], "weight") + + + class SineGen(torch.nn.Module): +@@ -319,14 +320,11 @@ class HiFTGenerator(nn.Module): + def remove_weight_norm(self): + print('Removing weight norm...') + for l in self.ups: +- remove_weight_norm(l) ++ remove_parametrizations(l, 'weight') + for l in self.resblocks: + l.remove_weight_norm() +- remove_weight_norm(self.conv_pre) +- remove_weight_norm(self.conv_post) +- self.m_source.remove_weight_norm() +- for l in self.source_downs: +- remove_weight_norm(l) ++ remove_parametrizations(self.conv_pre, 'weight') ++ remove_parametrizations(self.conv_post, 'weight') + for l in self.source_resblocks: + l.remove_weight_norm() + +@@ -346,9 +344,7 @@ class HiFTGenerator(nn.Module): + self.istft_params["n_fft"], window=self.stft_window.to(magnitude.device)) + return inverse_transform + +- def decode(self, x: torch.Tensor, s: torch.Tensor = torch.zeros(1, 1, 0)) -> torch.Tensor: +- s_stft_real, s_stft_imag = self._stft(s.squeeze(1)) +- s_stft = torch.cat([s_stft_real, s_stft_imag], dim=1) ++ def decode(self, x: torch.Tensor, s_stft: torch.Tensor, index: torch.int) -> torch.Tensor: + + x = self.conv_pre(x) + for i in range(self.num_upsamples): +@@ -356,7 +352,7 @@ class HiFTGenerator(nn.Module): + x = self.ups[i](x) + + if i == self.num_upsamples - 1: +- x = self.reflection_pad(x) ++ x = torch.cat((x, x[:,:,-2:-1]), -1) + + # fusion + si = self.source_downs[i](s_stft) +@@ -373,12 +369,10 @@ class HiFTGenerator(nn.Module): + + x = F.leaky_relu(x) + x = self.conv_post(x) +- magnitude = torch.exp(x[:, :self.istft_params["n_fft"] // 2 + 1, :]) +- phase = torch.sin(x[:, self.istft_params["n_fft"] // 2 + 1:, :]) # actually, sin is redundancy ++ magnitude = torch.exp(x[:, :index, :]) 
++ phase = torch.sin(x[:, index:, :]) # actually, sin is redundancy + +- x = self._istft(magnitude, phase) +- x = torch.clamp(x, -self.audio_limit, self.audio_limit) +- return x ++ return magnitude, phase + + def forward( + self, +@@ -407,5 +401,12 @@ class HiFTGenerator(nn.Module): + # use cache_source to avoid glitch + if cache_source.shape[2] != 0: + s[:, :, :cache_source.shape[2]] = cache_source +- generated_speech = self.decode(x=speech_feat, s=s) ++ # torchair编译,对decode函数做部分适配 ++ s_stft_real, s_stft_imag = self._stft(s.squeeze(1)) ++ s_stft = torch.cat([s_stft_real, s_stft_imag], dim=1) ++ # 字典取值操作无法被dynamo编译,把decode内部的index拿到外面计算 ++ index = self.istft_params["n_fft"] // 2 + 1 ++ magnitude, phase = self.decode(x=speech_feat, s_stft=s_stft, index=index) ++ x = self._istft(magnitude, phase) ++ generated_speech = torch.clamp(x, -self.audio_limit, self.audio_limit) + return generated_speech, s +diff --git a/cosyvoice/llm/llm.py b/cosyvoice/llm/llm.py +index bbd3305..ce82157 100644 +--- a/cosyvoice/llm/llm.py ++++ b/cosyvoice/llm/llm.py +@@ -229,16 +229,17 @@ class Qwen2Encoder(torch.nn.Module): + super().__init__() + self.model = Qwen2ForCausalLM.from_pretrained(pretrain_path) + +- def forward_one_step(self, xs, masks, cache=None): +- input_masks = masks[:, -1, :] +- outs = self.model( +- inputs_embeds=xs, +- attention_mask=input_masks, +- output_hidden_states=True, +- return_dict=True, +- use_cache=True, +- past_key_values=cache, +- ) ++ def forward_one_step(self, xs, masks, prompt_length, cache=None): ++ with torch.no_grad(): ++ outs = self.model( ++ inputs_embeds=xs, ++ attention_mask=masks, ++ prompt_length=prompt_length, ++ output_hidden_states=True, ++ return_dict=True, ++ use_cache=True, ++ past_key_values=cache, ++ ) + xs = outs.hidden_states[-1] + new_cache = outs.past_key_values + return xs, new_cache +@@ -318,10 +319,17 @@ class Qwen2LM(TransformerLM): + # 5. 
step by step decode + out_tokens = [] + cache = None ++ input_length = lm_input.shape[1] + for i in range(max_len): ++ prompt_length = input_length + i ++ if i == 0: ++ masks = torch.tril(torch.ones((1, lm_input.shape[1], lm_input.shape[1]), device=lm_input.device)).to(torch.bool).logical_not() ++ else: ++ masks = None + y_pred, cache = self.llm.forward_one_step(lm_input, +- masks=torch.tril(torch.ones((1, lm_input.shape[1], lm_input.shape[1]), device=lm_input.device)).to(torch.bool), +- cache=cache) ++ masks=masks, ++ prompt_length=prompt_length, ++ cache=cache) + logp = self.llm_decoder(y_pred[:, -1]).log_softmax(dim=-1) + top_ids = self.sampling_ids(logp.squeeze(dim=0), out_tokens, sampling, ignore_eos=True if i < min_len else False).item() + if top_ids == self.speech_token_size: +@@ -331,7 +339,7 @@ class Qwen2LM(TransformerLM): + # in stream mode, yield token one by one + yield top_ids + out_tokens.append(top_ids) +- lm_input = self.speech_embedding.weight[top_ids].reshape(1, 1, -1) ++ lm_input = self.speech_embedding.weight[top_ids].reshape(1, 1, -1).detach().clone() + + @torch.inference_mode() + def inference_bistream( +diff --git a/cosyvoice/utils/common.py b/cosyvoice/utils/common.py +index 3e61a8c..f6c4346 100644 +--- a/cosyvoice/utils/common.py ++++ b/cosyvoice/utils/common.py +@@ -107,12 +107,33 @@ def init_weights(m, mean=0.0, std=0.01): + + # Repetition Aware Sampling in VALL-E 2 + def ras_sampling(weighted_scores, decoded_tokens, sampling, top_p=0.8, top_k=25, win_size=10, tau_r=0.1): +- top_ids = nucleus_sampling(weighted_scores, top_p=top_p, top_k=top_k) ++ top_ids = dst_sampling(weighted_scores, top_p=top_p, top_k=top_k) + rep_num = (torch.tensor(decoded_tokens[-win_size:]).to(weighted_scores.device) == top_ids).sum().item() + if rep_num >= win_size * tau_r: + top_ids = random_sampling(weighted_scores, decoded_tokens, sampling) + return top_ids + ++def dst_sampling(weighted_scores, top_p=0.8, top_k=25): ++ ++ sorted_value, sorted_idx = weighted_scores.softmax(dim=0).sort(descending=True, stable=True) ++ ++ cum_sum = torch.cumsum(sorted_value, dim=0) ++ n = sorted_value.size(0) ++ device = cum_sum.device ++ pre_cum_sum = torch.cat([torch.zeros(1, device=device), cum_sum[:-1]]) ++ ++ indices = torch.arange(n ,device=device) ++ condition = (pre_cum_sum < top_p) & (indices < top_k) ++ ++ max_i_tensor = torch.where(condition, indices, torch.tensor(-1, device=device)) ++ n_selected = max_i_tensor.max() + 1 ++ ++ selected_prob = sorted_value[:n_selected] ++ selected_indices = sorted_idx[:n_selected] ++ ++ top_ids = selected_indices[selected_prob.multinomial(1, replacement=True)] ++ ++ return top_ids + + def nucleus_sampling(weighted_scores, top_p=0.8, top_k=25): + prob, indices = [], [] +-- +2.37.1 + diff --git a/ACL_PyTorch/built-in/audio/CosyVoice/CosyVoice2/infer.py b/ACL_PyTorch/built-in/audio/CosyVoice/CosyVoice2/infer.py new file mode 100755 index 0000000000000000000000000000000000000000..8743a684fe66f71842af76bdde0c0c744de801fe --- /dev/null +++ b/ACL_PyTorch/built-in/audio/CosyVoice/CosyVoice2/infer.py @@ -0,0 +1,55 @@ +# Copyright (c) 2025 Huawei Technologies Co., Ltd +# [Software Name] is licensed under Mulan PSL v2. +# You can use this software according to the terms and conditions of the Mulan PSL v2. 
+# You may obtain a copy of Mulan PSL v2 at: +# http://license.coscl.org.cn/MulanPSL2 +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, +# EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, +# MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. +# See the Mulan PSL v2 for more details. + +import argparse +import torch +import torchaudio +import torch_npu +from torch_npu.contrib import transfer_to_npu +import torchair as tng +from torchair.configs.compiler_config import CompilerConfig +from cosyvoice.cli.cosyvoice import CosyVoice2 +from cosyvoice.utils.file_utils import load_wav + + +if __name__ == '__main__': + torch_npu.npu.set_compile_mode(jit_compile=False) + + parser = argparse.ArgumentParser(description="CosyVoice infer") + parser.add_argument("--model_path", type=str, help="model path") + parser.add_argument('--warm_up_times', default=2, type=int, help='warm up times') + parser.add_argument('--infer_count', default=20, type=int, help='infer loop count') + parser.add_argument('--stream', action="store_true", help='stream infer') + args = parser.parse_args() + + cosyvoice = CosyVoice2(args.model_path, load_om=True, fp16=True) + cosyvoice.model.llm.eval() + cosyvoice.model.llm.llm.model.model.half() + + # 对hift模型结构进行torchair图模式适配 + cosyvoice.model.hift.remove_weight_norm() + config = CompilerConfig() + config.experimental_config.frozen_parameter = True + config.experimental_config.tiling_schedule_optimize = True + npu_backend = tng.get_npu_backend(compiler_config=config) + cosyvoice.model.hift.decode = torch.compile(cosyvoice.model.hift.decode, dynamic=True, fullgraph=True, backend=npu_backend) + + + # 输入数据加载 + prompt_txt = '收到好友从远方寄来的生日礼物,那份意外的惊喜和深深的祝福,让我心中充满了甜蜜的快乐,笑容如花儿般绽放。' + + with torch.no_grad(): + print('warm up start') + for _ in range(args.warm_up_times): + next(cosyvoice.inference_sft(prompt_txt, '中文女', stream=args.stream)) + print('warm up end') + for _ in range(args.infer_count): + for i, j in enumerate(cosyvoice.inference_sft(prompt_txt, '中文女', stream=args.stream)): + torchaudio.save('sft_{}.wav'.format(i), j['tts_speech'], cosyvoice.sample_rate) \ No newline at end of file diff --git a/ACL_PyTorch/built-in/audio/CosyVoice/CosyVoice2/modeling_qwen2.py b/ACL_PyTorch/built-in/audio/CosyVoice/CosyVoice2/modeling_qwen2.py new file mode 100755 index 0000000000000000000000000000000000000000..0b9eef3138f237e0292c39c0f1cbe35fc967f450 --- /dev/null +++ b/ACL_PyTorch/built-in/audio/CosyVoice/CosyVoice2/modeling_qwen2.py @@ -0,0 +1,880 @@ +# Copyright (c) 2025 Huawei Technologies Co., Ltd +# [Software Name] is licensed under Mulan PSL v2. +# You can use this software according to the terms and conditions of the Mulan PSL v2. +# You may obtain a copy of Mulan PSL v2 at: +# http://license.coscl.org.cn/MulanPSL2 +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, +# EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, +# MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. +# See the Mulan PSL v2 for more details. 
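
本文件(modeling_qwen2.py)将Qwen2的RMSNorm替换为昇腾融合算子npu_rms_norm/npu_add_rms_norm(见下文Qwen2RMSNorm类)。其数学语义可用如下纯PyTorch参考实现说明(仅作示意,假设与transformers原始Qwen2RMSNorm语义一致,不属于本文件代码):

```python
import torch

def rms_norm_ref(x, weight, eps=1e-6):
    # RMSNorm:按最后一维做均方根归一化,再乘以可学习权重
    variance = x.float().pow(2).mean(-1, keepdim=True)
    return (x.float() * torch.rsqrt(variance + eps)).to(x.dtype) * weight

def add_rms_norm_ref(residual, hidden, weight, eps=1e-6):
    # 融合残差相加:返回归一化结果y和新的residual(即residual+hidden)
    x = residual + hidden
    return rms_norm_ref(x, weight, eps), x

h = torch.randn(1, 4, 8, dtype=torch.float16)
w = torch.ones(8, dtype=torch.float16)
y, new_residual = add_rms_norm_ref(torch.zeros_like(h), h, w)
print(y.shape, new_residual.shape)
```
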
+""" PyTorch Qwen2 model.""" +import inspect +import math +import warnings +from typing import List, Optional, Tuple, Union +import time +import torch +import torch_npu +import torchair as tng +from torchair.configs.compiler_config import CompilerConfig +import torch.nn.functional as F +import torch.utils.checkpoint +from torch import nn +from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss +from ...activations import ACT2FN +from ...cache_utils import Cache, DynamicCache +from ...modeling_attn_mask_utils import _prepare_4d_causal_attention_mask, _prepare_4d_causal_attention_mask_for_sdpa +from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast +from ...modeling_utils import PreTrainedModel +from ...utils import ( + add_start_docstrings, + add_start_docstrings_to_model_forward, + is_flash_attn_2_available, + is_flash_attn_greater_or_equal_2_10, + logging, + replace_return_docstrings, +) +from .configuration_qwen2 import Qwen2Config + + +if is_flash_attn_2_available(): + from flash_attn import flash_attn_func, flash_attn_varlen_func + from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa + + _flash_supports_window_size = "window_size" in list(inspect.signature(flash_attn_func).parameters) + + +logger = logging.get_logger(__name__) + + +_CHECKPOINT_FOR_DOC = "Qwen/Qwen2-7B-beta" +_CONFIG_FOR_DOC = "Qwen2Config" + +QWEN2_PRETRAINED_MODEL_ARCHIVE_LIST = [ + "Qwen/Qwen2-7B-beta", +] + + +# Ascend优化:Add/Norm昇腾自定义融合算子 +class Qwen2RMSNorm(nn.Module): + def __init__(self, hidden_size, eps=1e-6): + """ + Qwen2RMSNorm is equivalent to T5LayerNorm + """ + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, + hidden_states, + residual: Optional[torch.Tensor] = None): + if residual is None: + return torch_npu.npu_rms_norm(hidden_states, self.weight, self.variance_epsilon)[0], hidden_states + else: + y, _, x = torch_npu.npu_add_rms_norm(residual, hidden_states, self.weight, self.variance_epsilon) + return y, x + + +# Ascend优化:提前计算位置编码,无需在每层layer中重复计算 +class Qwen2RotaryEmbedding(nn.Module): + def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None): + super().__init__() + + self.dim = dim + self.max_position_embeddings = max_position_embeddings + self.base = base + inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + + # Build here to make `torch.jit.trace` work. 
+ self._set_cos_sin_cache( + seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype() + ) + + def _set_cos_sin_cache(self, seq_len, device, dtype): + self.max_seq_len_cached = seq_len + t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype) + + freqs = torch.outer(t, self.inv_freq) + # Different from paper, but it uses a different permutation in order to obtain the same calculation + emb = torch.cat((freqs, freqs), dim=-1) + self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False) + self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False) + + def forward(self, x=None, seq_len=None): + if x is None and seq_len is None: + return self.cos_cached, self.sin_cached + + return ( + self.cos_cached.to(dtype=x.dtype), + self.sin_cached.to(dtype=x.dtype), + ) + + +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., :x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2:] + return torch.cat((-x2, x1), dim=-1) + + +def apply_rotary_pos_emb(q, k, cos, sin): + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + + +class Qwen2MLP(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.hidden_size = config.hidden_size + self.intermediate_size = config.intermediate_size + self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) + self.act_fn = ACT2FN[config.hidden_act] + + def forward(self, x): + return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + + +class Qwen2Attention(nn.Module): + """ + Multi-headed attention from 'Attention Is All You Need' paper. Modified to use sliding window attention: Longformer + and "Generating Long Sequences with Sparse Transformers". + """ + + def __init__(self, config: Qwen2Config, layer_idx: Optional[int] = None): + super().__init__() + self.config = config + self.layer_idx = layer_idx + if layer_idx is None: + logger.warning_once( + f"Instantiating {self.__class__.__name__} without passing `layer_idx` is not recommended and will " + "to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` " + "when creating this class." + ) + + self.hidden_size = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = self.hidden_size // self.num_heads + self.num_key_value_heads = config.num_key_value_heads + self.num_key_value_groups = self.num_heads // self.num_key_value_heads + self.max_position_embeddings = config.max_position_embeddings + self.rope_theta = config.rope_theta + self.is_causal = True + self.attention_dropout = config.attention_dropout + + if (self.head_dim * self.num_heads) != self.hidden_size: + raise ValueError( + f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}" + f" and `num_heads`: {self.num_heads})." 
+ ) + self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=True) + self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True) + self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True) + self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False) + + self.rotary_emb = Qwen2RotaryEmbedding( + self.head_dim, + max_position_embeddings=self.max_position_embeddings, + base=self.rope_theta, + ) + + +# Ascend优化:PFA/IFA自定义算子替换,kv cache固定shape并在指定位置更新 +class Qwen2SdpaAttention(Qwen2Attention): + """ + Qwen2 attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from + `Qwen2Attention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to + SDPA API. + """ + + # 优化Attention部分逻辑,替换torch_npu算子 + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + updated_kv_positions: Optional[torch.LongTensor] = None, + kv_padding_size: Optional[torch.LongTensor] = None, + actual_seq_len: Optional[list] = None, + rotary_emb_cos: Optional[torch.Tensor] = None, + rotary_emb_sin: Optional[torch.Tensor] = None, + output_attentions: bool = False, + use_cache: bool = False, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + if output_attentions: + logger.warning_once( + "Qwen2Model is using Qwen2SdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, " + 'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' 
+ ) + return super().forward( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + ) + + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim) + + # 利用已经提前计算好的位置编码数据对q,k值进行更新 + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, + rotary_emb_cos.to(value_states.dtype), + rotary_emb_sin.to(value_states.dtype)) + + if use_cache and past_key_value is not None: + # 把计算好的kv值更新到kv cahce中 + tmp_ids = updated_kv_positions.reshape(-1) + torch_npu.scatter_update_(past_key_value.key_cache[self.layer_idx], tmp_ids, key_states, 1) + torch_npu.scatter_update_(past_key_value.value_cache[self.layer_idx], tmp_ids, value_states, 1) + kv_states = past_key_value[self.layer_idx] if q_len == 1 else (key_states, value_states) + key_states = kv_states[0] + value_states = kv_states[1] + + + if q_len > 1: + # prefill阶段利用PFA自定义算子执行计算,因为bs为1,mask固定为下三角全为0上三角全为负无穷的倒三角mask矩阵 + attn_output = torch_npu.npu_prompt_flash_attention(query_states, key_states.contiguous(), + value_states.contiguous(), num_heads=self.num_heads, + input_layout="BSND", + scale_value=1 / math.sqrt(self.head_dim), + pre_tokens=65535, next_tokens=0, + atten_mask=attention_mask, + num_key_value_heads=self.num_key_value_heads) + else: + # decode阶段利用IFA自定义算子执行计算,qkv的sequence都为1,该算子采用tiling下沉,视为静态算子,支持整图下发 + attn_output = torch_npu.npu_incre_flash_attention(query_states, key_states.contiguous(), + value_states.contiguous(), num_heads=self.num_heads, + input_layout="BSND", + scale_value=1 / math.sqrt(self.head_dim), + atten_mask=None, + actual_seq_lengths=actual_seq_len, + kv_padding_size=kv_padding_size, + num_key_value_heads=self.num_key_value_heads) + + attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) + + attn_output = self.o_proj(attn_output) + + return attn_output, None, past_key_value + + +QWEN2_ATTENTION_CLASSES = { + "eager": Qwen2Attention, + "flash_attention_2": Qwen2FlashAttention2, + "sdpa": Qwen2SdpaAttention, +} + + +# Ascend优化:每层layer的前后rms替换为昇腾自定义算子 +class Qwen2DecoderLayer(nn.Module): + def __init__(self, config: Qwen2Config, layer_idx: int): + super().__init__() + self.hidden_size = config.hidden_size + + if config.use_sliding_window and config._attn_implementation != "flash_attention_2": + logger.warning_once( + f"Sliding Window Attention is enabled but not implemented for `{config._attn_implementation}`; " + "unexpected results may be encountered." 
+ ) + self.self_attn = QWEN2_ATTENTION_CLASSES[config._attn_implementation](config, layer_idx) + + self.mlp = Qwen2MLP(config) + self.input_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + def forward( + self, + hidden_states: torch.Tensor, + past_residual: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + updated_kv_positions: Optional[torch.LongTensor] = None, + kv_padding_size: Optional[torch.LongTensor] = None, + actual_seq_len: Optional[list] = None, + rotary_emb_cos: Optional[torch.Tensor] = None, + rotary_emb_sin: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + **kwargs, + ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: + if "padding_mask" in kwargs: + warnings.warn( + "Passing `padding_mask` is deprecated and will be removed in v4.37. " + "Please make sure use `attention_mask` instead.`" + ) + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`, *optional*): attention mask of size + `(batch, sequence_length)` where padding elements are indicated by 0. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding + (see `past_key_values`). + past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states + """ + + # rms计算替换为昇腾自定义融合算子 + hidden_states, residual = self.input_layernorm(hidden_states, past_residual) + + # Self Attention + hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + updated_kv_positions=updated_kv_positions, + kv_padding_size=kv_padding_size, + actual_seq_len=actual_seq_len, + rotary_emb_cos=rotary_emb_cos, + rotary_emb_sin=rotary_emb_sin, + use_cache=use_cache, + ) + + # rms计算替换为昇腾自定义融合算子 + hidden_states, residual = self.post_attention_layernorm(hidden_states, residual) + hidden_states = self.mlp(hidden_states) + + outputs = (residual, hidden_states,) + + if output_attentions: + outputs += (self_attn_weights,) + + if use_cache: + outputs += (present_key_value,) + + return outputs + + +@add_start_docstrings( + "The bare Qwen2 Model outputting raw hidden-states without any specific head on top.", + QWEN2_START_DOCSTRING, +) +class Qwen2PreTrainedModel(PreTrainedModel): + config_class = Qwen2Config + base_model_prefix = "model" + supports_gradient_checkpointing = True + _no_split_modules = ["Qwen2DecoderLayer"] + _skip_keys_device_placement = "past_key_values" + _supports_flash_attn_2 = True + _supports_sdpa = True + _supports_cache_class = True + + def _init_weights(self, module): + std = self.config.initializer_range + if isinstance(module, nn.Linear): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if 
module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + + +# Ascend优化:forward函数利用torchair编译为图模式,利用cache接口避免重复编译 +@add_start_docstrings( + "The bare Qwen2 Model outputting raw hidden-states without any specific head on top.", + QWEN2_START_DOCSTRING, +) +class Qwen2Model(Qwen2PreTrainedModel): + """ + Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`Qwen2DecoderLayer`] + + Args: + config: Qwen2Config + """ + + def __init__(self, config: Qwen2Config): + super().__init__(config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + self.max_position_embeddings = config.max_position_embeddings + + self.hidden_size = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = self.hidden_size // self.num_heads + self.rope_theta = config.rope_theta + + self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) + self.layers = nn.ModuleList( + [Qwen2DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] + ) + self._attn_implementation = config._attn_implementation + self.norm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.rotary_emb = Qwen2RotaryEmbedding( + self.head_dim, + max_position_embeddings=self.max_position_embeddings, + base=self.rope_theta, + ) + + self.gradient_checkpointing = False + # Initialize weights and apply final processing + self.post_init() + + # torchair编译参数,编译Qwen2Model的forward部分 + config = CompilerConfig() + config.experimental_config.frozen_parameter = True + # tiling下沉,主要针对IFA算子,使其算子tiling操作在AICPU上执行 + config.experimental_config.tiling_schedule_optimize = True + # torchair的cache编译,保证模型编译cache文件,避免重复推理 + self.cached_decode = tng.inference.cache_compile(self.decode, config=config) + self.cached_prefill = tng.inference.cache_compile(self.prefill, config=config) + + def get_input_embeddings(self): + return self.embed_tokens + + def set_input_embeddings(self, value): + self.embed_tokens = value + + def _prepare_decoder_rotary_cos_sin(self, position_ids): + cos, sin = self.rotary_emb() + f_position_ids = position_ids.flatten() + cos = torch.index_select(cos, 0, f_position_ids) + sin = torch.index_select(sin, 0, f_position_ids) + cos = cos.reshape(position_ids.size(0), position_ids.size(1), -1).unsqueeze(2) + sin = sin.reshape(position_ids.size(0), position_ids.size(1), -1).unsqueeze(2) + return cos, sin + + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + updated_kv_positions: Optional[torch.LongTensor] = None, + kv_padding_size: Optional[torch.LongTensor] = None, + actual_seq_len: Optional[list] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + lm_head: Optional[function] = None + ) -> Union[Tuple, BaseModelOutputWithPast]: + # prefill和decode需要编译为两个不同的模型 + if inputs_embeds.size(1) > 1: + return self.cached_prefill( + input_ids, + attention_mask, + position_ids, + past_key_values, + updated_kv_positions, + kv_padding_size, + actual_seq_len, + inputs_embeds, + use_cache, + output_attentions, + output_hidden_states, + return_dict, + lm_head + ) + else: + return self.cached_decode( + input_ids, + attention_mask, + position_ids, + 
past_key_values, + updated_kv_positions, + kv_padding_size, + actual_seq_len, + inputs_embeds, + use_cache, + output_attentions, + output_hidden_states, + return_dict, + lm_head + ) + + def decode( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + updated_kv_positions: Optional[torch.LongTensor] = None, + kv_padding_size: Optional[torch.LongTensor] = None, + actual_seq_len: Optional[list] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + lm_head: Optional[function] = None + ): + return self._forward( + input_ids, + attention_mask, + position_ids, + past_key_values, + updated_kv_positions, + kv_padding_size, + actual_seq_len, + inputs_embeds, + use_cache, + output_attentions, + output_hidden_states, + return_dict, + lm_head + ) + + def prefill( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + updated_kv_positions: Optional[torch.LongTensor] = None, + kv_padding_size: Optional[torch.LongTensor] = None, + actual_seq_len: Optional[list] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + lm_head: Optional[function] = None + ): + return self._forward( + input_ids, + attention_mask, + position_ids, + past_key_values, + updated_kv_positions, + kv_padding_size, + actual_seq_len, + inputs_embeds, + use_cache, + output_attentions, + output_hidden_states, + return_dict, + lm_head + ) + + + @add_start_docstrings_to_model_forward(QWEN2_INPUTS_DOCSTRING) + def _forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + updated_kv_positions: Optional[torch.LongTensor] = None, + kv_padding_size: Optional[torch.LongTensor] = None, + actual_seq_len: Optional[list] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + lm_head: Optional[function] = None + ) -> Union[Tuple, BaseModelOutputWithPast]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # retrieve input_ids and inputs_embeds + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time") + elif input_ids is not None: + batch_size, seq_length = input_ids.shape + elif inputs_embeds is not None: + batch_size, seq_length, _ = inputs_embeds.shape + else: + raise ValueError("You have to specify either decoder_input_ids 
or decoder_inputs_embeds") + + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." + ) + use_cache = False + + # prefill阶段初始化kv cache,decode阶段对kv cache进行更新 + # 固定kv cache为最大shape,避免内存的重复申请和拷贝,也保证了模型的静态shape,可整图下发推理 + if use_cache: + use_legacy_cache = not isinstance(past_key_values, Cache) + if use_legacy_cache: + kv_shape = ( + batch_size, self.config.max_position_embeddings, + self.config.num_key_value_heads, + self.config.hidden_size // self.config.num_attention_heads) + past_key_values = () + for _ in range(self.config.num_hidden_layers): + k_cache = torch.zeros(kv_shape, dtype=inputs_embeds.dtype, device=inputs_embeds.device) + v_cache = torch.zeros(kv_shape, dtype=inputs_embeds.dtype, device=inputs_embeds.device) + past_key_values += ((k_cache, v_cache),) + past_key_values = DynamicCache.from_legacy_cache(past_key_values) + + past_key_values_length = self.max_position_embeddings if seq_length == 1 else 0 + + if position_ids is None: + device = input_ids.device if input_ids is not None else inputs_embeds.device + position_ids = torch.arange( + past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device + ) + position_ids = position_ids.unsqueeze(0).view(-1, seq_length) + else: + position_ids = position_ids.view(-1, seq_length).long() + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + if attention_mask is not None and self._attn_implementation == "flash_attention_2" and use_cache: + is_padding_right = attention_mask[:, -1].sum().item() != batch_size + if is_padding_right: + raise ValueError( + "You are attempting to perform batched generation with padding_side='right'" + " this may lead to unexpected behaviour for Flash Attention version of Qwen2. Make sure to " + " call `tokenizer.padding_side = 'left'` before tokenizing the input. 
" + ) + + hidden_states = inputs_embeds + + # 此处统一计算位置编码,在每个layer中取对应位置的值 + rotary_emb_cos, rotary_emb_sin = self._prepare_decoder_rotary_cos_sin(position_ids) + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + next_decoder_cache = None + + residual = None + for decoder_layer in self.layers: + if output_hidden_states: + all_hidden_states += (hidden_states,) + + # 执行layer层推理 + layer_outputs = decoder_layer( + hidden_states, + attention_mask=attention_mask, + past_residual=residual, + position_ids=position_ids, + past_key_value=past_key_values, + updated_kv_positions=updated_kv_positions, + kv_padding_size=kv_padding_size, + actual_seq_len=actual_seq_len, + rotary_emb_cos=rotary_emb_cos, + rotary_emb_sin=rotary_emb_sin, + output_attentions=output_attentions, + use_cache=use_cache, + ) + + residual = layer_outputs[0] + hidden_states = layer_outputs[1] + + if use_cache: + next_decoder_cache = layer_outputs[3 if output_attentions else 2] + + if output_attentions: + all_self_attns += (layer_outputs[2],) + + # norm计算,此处替换为昇腾融合算子 + hidden_states, _ = self.norm(hidden_states, residual) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + next_cache = None + if use_cache: + next_cache = next_decoder_cache.to_legacy_cache() if use_legacy_cache else next_decoder_cache + + if not return_dict: + return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) + + out = BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=next_cache, + hidden_states=all_hidden_states, + attentions=all_self_attns, + ) + + hidden_states = out[0] + # 由于logits最后也只取[:,-1,:],相当于只取最新seq位置上的数据,l + # 所以在全量的最后线性层计算可以只对最新的seq位置做计算,降低计算量 + bs, seq, hidden = hidden_states.size() + if seq > 1: + gather_index = torch.ones(bs, dtype=torch.int64, device=hidden_states.device) * (seq - 1) + gather_index = gather_index.unsqueeze(dim=1).unsqueeze(dim=2).repeat(1, 1, hidden) + hidden_states = torch.gather(hidden_states, 1, gather_index) + logits = lm_head(hidden_states) + logits = logits.float() + return out, logits + + +class Qwen2ForCausalLM(Qwen2PreTrainedModel): + _tied_weights_keys = ["lm_head.weight"] + + def __init__(self, config): + super().__init__(config) + self.model = Qwen2Model(config) + self.vocab_size = config.vocab_size + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def set_decoder(self, decoder): + self.model = decoder + + def get_decoder(self): + return self.model + + @add_start_docstrings_to_model_forward(QWEN2_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + prompt_length: Optional[torch.LongTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = 
None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None + ) -> Union[Tuple, CausalLMOutputWithPast]: + r""" + 对CosyVoice2模型中使用的Qwen模型进行昇腾适配优化,具体优化点有: + 1. 固定KV CACHE大小,避免重复申请内存和拷贝 + 2. 替换部分算子为昇腾自定义算子 + 3. 首层计算位置编码避免重复计算 + 4. 在decode阶段,固定输入shape大小,保证整图下发 + + 模型有以下输入: + 1. attention_mask + 2. inputs_embeds:CosyVoice会把inputs_ids处理embeding后输入模型 + 3. past_key_values:kv cache,在每次推理后会进行更新 + 4. position_ids:位置id,在每次推理后会进行更新 + 5. prompt_length:实际输入长度,在prefill阶段为首token长度,后续每次推理长度加1 + """ + + # 每次推理前对输入数据进行昇腾适配处理,处理为昇腾自定义算子所需类型参数 + updated_kv_positions, past_key_values, position_ids, kv_padding_size, actual_seq_len = self.prepare_data(inputs_embeds, past_key_values, prompt_length) + + model_inputs = { + "inputs_embeds": inputs_embeds, + "past_key_values": past_key_values, + "position_ids": position_ids, + "kv_padding_size": kv_padding_size, + "actual_seq_len": actual_seq_len, + "attention_mask": attention_mask, + } + + # prefill阶段由于输出token长度不固定,为动态shape推理。decode阶段把输入固定为静态,保证整图静态推理。 + if inputs_embeds.shape[1] == 1: + self._mark_model_inputs_static(model_inputs) + + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # 主要推理阶段,利用torchair编译为整图推理 + outputs, logits = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + updated_kv_positions=updated_kv_positions, + kv_padding_size=kv_padding_size, + actual_seq_len=actual_seq_len, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + lm_head=self.lm_head + ) + + loss = None + if labels is not None: + # Shift so that tokens < n predict n + shift_logits = logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + # Flatten the tokens + loss_fct = CrossEntropyLoss() + shift_logits = shift_logits.view(-1, self.config.vocab_size) + shift_labels = shift_labels.view(-1) + # Enable model parallelism + shift_labels = shift_labels.to(shift_logits.device) + loss = loss_fct(shift_logits, shift_labels) + + if not return_dict: + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + + return CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + # Ascend优化:把数据输入处理为Ascend优化所需要的格式和类型 + def prepare_data(self, inputs_embeds, past_key_values, prompt_length): + bsz = inputs_embeds.shape[0] + seq_length = inputs_embeds.shape[1] + if past_key_values: + past_key_values = DynamicCache.from_legacy_cache(past_key_values) + if seq_length > 1: + updated_kv_positions = torch.zeros(bsz, dtype=torch.long, device=inputs_embeds.device) + position_ids = None + else: + updated_kv_positions = torch.ones(bsz, dtype=torch.long, device=inputs_embeds.device) * (prompt_length - 1) + position_ids = torch.tensor([prompt_length], device=inputs_embeds.device) + + # ifa Computational optimization inputs + kv_padding_size = torch.tensor(self.config.max_position_embeddings - prompt_length, device=inputs_embeds.device) + actual_seq_len = ([prompt_length]) + + return updated_kv_positions, 
past_key_values, position_ids, kv_padding_size, actual_seq_len + + # Ascend优化:固定input shape,使能静态推理,模型整图下发 + def _mark_model_inputs_static(self, model_inputs): + for key, value in model_inputs.items(): + if key == "past_key_values" and value is not None: + for i in range(self.config.num_hidden_layers): + torch._dynamo.mark_static(value[i][0]) + torch._dynamo.mark_static(value[i][1]) + elif isinstance(value, torch.Tensor): + torch._dynamo.mark_static(value) + diff --git a/ACL_PyTorch/built-in/audio/CosyVoice/CosyVoice2/modify_onnx.py b/ACL_PyTorch/built-in/audio/CosyVoice/CosyVoice2/modify_onnx.py new file mode 100755 index 0000000000000000000000000000000000000000..d2a312fd332a0a75ea869ebb8ef1cc3c46156d59 --- /dev/null +++ b/ACL_PyTorch/built-in/audio/CosyVoice/CosyVoice2/modify_onnx.py @@ -0,0 +1,32 @@ +# Copyright (c) 2025 Huawei Technologies Co., Ltd +# [Software Name] is licensed under Mulan PSL v2. +# You can use this software according to the terms and conditions of the Mulan PSL v2. +# You may obtain a copy of Mulan PSL v2 at: +# http://license.coscl.org.cn/MulanPSL2 +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, +# EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, +# MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. +# See the Mulan PSL v2 for more details. + +import os +import sys +from auto_optimizer import OnnxGraph + + +def modify_speech_token(speech_token): + ReduceMax_list = speech_token.get_nodes("ReduceMax") + for node in ReduceMax_list: + if node.attrs['keepdims'] == 0: + out_nodes = speech_token.get_next_nodes(node.outputs[0]) + for out_node in out_nodes: + if out_node.op_type == "Unsqueeze": + node.attrs['keepdims'] = 1 + speech_token.remove(out_node.name) + return speech_token + +if __name__ == '__main__': + input_path = sys.argv[1] + speech_token = OnnxGraph.parse(os.path.join(input_path, "speech_tokenizer_v2.onnx")) + # om图模式不支持无shape输出,reducemax需要修改keepdim + modify_onnx = modify_speech_token(speech_token) + modify_onnx.save(os.path.join(input_path, "speech_token_md.onnx")) \ No newline at end of file diff --git a/ACL_PyTorch/built-in/audio/CosyVoice/CosyVoice2/requirements.txt b/ACL_PyTorch/built-in/audio/CosyVoice/CosyVoice2/requirements.txt new file mode 100755 index 0000000000000000000000000000000000000000..b2794ae2b99847a0023171c414f59ebdcc3c7592 --- /dev/null +++ b/ACL_PyTorch/built-in/audio/CosyVoice/CosyVoice2/requirements.txt @@ -0,0 +1,33 @@ +conformer==0.3.2 +deepspeed==0.14.2 +diffusers==0.27.2 +gdown==5.1.0 +gradio==4.32.2 +grpcio==1.57.0 +grpcio-tools==1.57.0 +huggingface-hub==0.23.5 +hydra-core==1.3.2 +HyperPyYAML==1.2.2 +inflect==7.3.1 +librosa==0.10.2 +lightning==2.2.4 +matplotlib==3.7.5 +modelscope==1.15.0 +networkx==3.1 +omegaconf==2.3.0 +onnx==1.16.0 +onnxruntime==1.16.0 +openai-whisper==20231117 +protobuf==4.25 +pydantic==2.7.0 +rich==13.7.1 +soundfile==0.12.1 +tensorboard==2.14.0 +torch==2.4.0 +torch_npu==2.4.0.post2 +torchaudio==2.4.0 +uvicorn==0.30.0 +wget==3.2 +fastapi==0.111.0 +fastapi-cli==0.0.4 +WeTextProcessing==1.0.3 \ No newline at end of file
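
补充说明:diff_CosyVoice.patch在cosyvoice/utils/common.py中新增的dst_sampling,将nucleus_sampling的逐token循环改写为向量化实现:按概率降序排序后,取"前缀累积概率小于top_p且序号小于top_k"的最小前缀作为候选集,再在候选集内按概率采样。下面的小例子演示候选集的构造逻辑(仅作示意,非仓库代码):

```python
import torch

def dst_candidates(weighted_scores, top_p=0.8, top_k=25):
    # 与补丁中dst_sampling相同的候选集构造方式
    prob, idx = weighted_scores.softmax(dim=0).sort(descending=True, stable=True)
    pre_cum = torch.cat([torch.zeros(1), torch.cumsum(prob, dim=0)[:-1]])
    n = int(((pre_cum < top_p) & (torch.arange(prob.size(0)) < top_k)).sum())
    return idx[:n], prob[:n]

scores = torch.tensor([2.0, 1.0, 0.5, 0.1, -1.0])
cand_ids, cand_prob = dst_candidates(scores, top_p=0.8, top_k=3)
top_id = cand_ids[cand_prob.multinomial(1, replacement=True)]  # 在候选集内按概率采样
print(cand_ids.tolist(), top_id.item())
```
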
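
补充说明:modeling_qwen2.py的核心优化是把KV Cache固定为max_position_embeddings大小:prefill阶段从位置0写入整段kv,decode阶段每步用torch_npu.scatter_update_原地写入位置prompt_length-1,并通过actual_seq_lengths与kv_padding_size告知IFA算子有效数据范围,从而保证decode输入shape固定、支持整图下发。下面用纯PyTorch做一个概念性演示(仅作示意,非实际实现):

```python
import torch

max_pos, n_kv_heads, head_dim = 32, 2, 4
k_cache = torch.zeros(1, max_pos, n_kv_heads, head_dim)  # 一次性申请最大长度,后续不再扩容

def update_cache(cache, new_k, pos):
    # 实际实现中由torch_npu.scatter_update_在NPU上原地写入指定位置
    cache[:, pos:pos + new_k.shape[1]] = new_k

prompt_len = 5
update_cache(k_cache, torch.randn(1, prompt_len, n_kv_heads, head_dim), 0)  # prefill:从位置0写入
update_cache(k_cache, torch.randn(1, 1, n_kv_heads, head_dim), prompt_len)  # 第一步decode:写入位置prompt_len
actual_seq_len = [prompt_len + 1]                # IFA只读取有效长度内的kv
kv_padding_size = max_pos - (prompt_len + 1)     # 其余位置视为padding
print(actual_seq_len, kv_padding_size)
```
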