diff --git a/ACL_PyTorch/built-in/audio/CosyVoice2/800I/diff_CosyVoice_800I.patch b/ACL_PyTorch/built-in/audio/CosyVoice2/800I/diff_CosyVoice_800I.patch
index 6bec7233c6f96591e27b498c811a4d4b30de91e0..d5021894c1e390bc9f0f368bfe716683b6938fd4 100644
--- a/ACL_PyTorch/built-in/audio/CosyVoice2/800I/diff_CosyVoice_800I.patch
+++ b/ACL_PyTorch/built-in/audio/CosyVoice2/800I/diff_CosyVoice_800I.patch
@@ -1,8 +1,8 @@
 diff --git a/cosyvoice/cli/cosyvoice.py b/cosyvoice/cli/cosyvoice.py
-index e2d62e2..dccea41 100644
+index e2d62e2..a0512a4 100644
 --- a/cosyvoice/cli/cosyvoice.py
 +++ b/cosyvoice/cli/cosyvoice.py
-@@ -13,11 +13,14 @@
+@@ -13,11 +13,15 @@
 # limitations under the License.
 import os
 import time
@@ -13,37 +13,48 @@ index e2d62e2..dccea41 100644
 from hyperpyyaml import load_hyperpyyaml
 from modelscope import snapshot_download
 import torch
++import acl
 +from ais_bench.infer.interface import InferSession
 from cosyvoice.cli.frontend import CosyVoiceFrontEnd
 from cosyvoice.cli.model import CosyVoiceModel, CosyVoice2Model
 from cosyvoice.utils.file_utils import logging
-@@ -126,7 +129,7 @@ class CosyVoice:
- 
+@@ -126,7 +130,7 @@ class CosyVoice:
+ 
 class CosyVoice2(CosyVoice):
- 
+ 
 - def __init__(self, model_dir, load_jit=False, load_trt=False, fp16=False):
 + def __init__(self, model_dir, load_jit=False, load_trt=False, fp16=False, load_om=False):
 self.instruct = True if '-Instruct' in model_dir else False
 self.model_dir = model_dir
 self.fp16 = fp16
-@@ -155,6 +158,16 @@ class CosyVoice2(CosyVoice):
+@@ -155,6 +159,26 @@ class CosyVoice2(CosyVoice):
 self.model.load_trt('{}/flow.decoder.estimator.{}.mygpu.plan'.format(model_dir, 'fp16' if self.fp16 is True else 'fp32'),
 '{}/flow.decoder.estimator.fp32.onnx'.format(model_dir),
 self.fp16)
 + if load_om:
++ soc_version = acl.get_soc_name()
++ context = None
++ if '910B3' in soc_version:
++ context, ret = acl.rt.get_context()
++ if ret:
++ raise RuntimeError(f"Get context failed, retcode is {ret}.")
 + arch = platform.machine()
 + system = platform.system().lower()
 + flow_om = InferSession(0, '{}/flow_{}_{}.om'.format(model_dir, system, arch))
 + flow_om_static = InferSession(0, '{}/flow_static.om'.format(model_dir))
 + speech_om = InferSession(0, '{}/speech_{}_{}.om'.format(model_dir, system, arch))
++ if '910B3' in soc_version:
++ ret = acl.rt.set_context(context)
++ if ret:
++ raise RuntimeError(f"Set context failed, retcode is {ret}.")
 + self.frontend.speech_om = speech_om
 + self.frontend.flow_om = flow_om
 + self.model.flow.decoder.flow_om_static = flow_om_static
 + self.model.flow.decoder.flow_om = flow_om
 del configs
- 
+ 
 def inference_instruct(self, *args, **kwargs):
-@@ -171,3 +184,19 @@ class CosyVoice2(CosyVoice):
+@@ -171,3 +195,19 @@ class CosyVoice2(CosyVoice):
 logging.info('yield speech len {}, rtf {}'.format(speech_len, (time.time() - start_time) / speech_len))
 yield model_output
 start_time = time.time()
@@ -73,7 +84,7 @@ index 6e10f00..25ad767 100644
 self.inflect_parser = inflect.engine()
 + self.speech_om = None
 + self.flow_om = None
- 
+ 
 def _extract_text_token(self, text):
 if isinstance(text, Generator):
 @@ -92,11 +94,16 @@ class CosyVoiceFrontEnd:
@@ -104,7 +115,7 @@ index 9ebf8cb..a8775a1 100644
 +++ b/cosyvoice/cli/model.py
 @@ -99,7 +99,7 @@ class CosyVoiceModel:
 self.flow.decoder.estimator = self.flow.decoder.estimator_engine.create_execution_context()
- 
+ 
 def llm_job(self, text, prompt_text, llm_prompt_speech_token, llm_embedding, uuid):
 - with self.llm_context:
 + with self.llm_context():
@@ -126,7 +137,7 @@ index 9ebf8cb..a8775a1 100644
 self.tts_speech_token_dict = {}
 self.llm_end_dict = {}
 self.hift_cache_dict = {}
- 
+ 
 + # added to support streaming input
 + self.first_chunk_size = 20
 + self.token_offset_dict = {}
@@ -274,15 +285,15 @@ index 6a60f6d..fbe7545 100644
 import torch.nn.functional as F
 +import numpy as np
 from matcha.models.components.flow_matching import BASECFM
- 
- 
+ 
+ 
 @@ -32,6 +33,8 @@ class ConditionalCFM(BASECFM):
 # Just change the architecture of the estimator here
 self.estimator = estimator
 self.lock = threading.Lock()
 + self.flow_om = None
 + self.flow_om_static = None
- 
+ 
 @torch.inference_mode()
 def forward(self, mu, mask, n_timesteps, temperature=1.0, spks=None, cond=None, prompt_len=0, flow_cache=torch.zeros(1, 80, 0, 2)):
 @@ -105,12 +108,26 @@ class ConditionalCFM(BASECFM):
@@ -329,17 +340,17 @@ index c47bf05..7f3e4ae 100644
 +from torch.nn.utils.parametrize import remove_parametrizations
 from torch.nn.utils.parametrizations import weight_norm
 from torch.distributions.uniform import Uniform
- 
+ 
 @@ -99,8 +100,8 @@ class ResBlock(torch.nn.Module):
- 
+ 
 def remove_weight_norm(self):
 for idx in range(len(self.convs1)):
 - remove_weight_norm(self.convs1[idx])
 - remove_weight_norm(self.convs2[idx])
 + remove_parametrizations(self.convs1[idx], "weight")
 + remove_parametrizations(self.convs2[idx], "weight")
- 
- 
+ 
+ 
 class SineGen(torch.nn.Module):
 @@ -319,14 +320,11 @@ class HiFTGenerator(nn.Module):
 def remove_weight_norm(self):
@@ -358,41 +369,41 @@ index c47bf05..7f3e4ae 100644
 + remove_parametrizations(self.conv_post, 'weight')
 for l in self.source_resblocks:
 l.remove_weight_norm()
- 
+ 
 @@ -346,9 +344,7 @@ class HiFTGenerator(nn.Module):
 self.istft_params["n_fft"], window=self.stft_window.to(magnitude.device))
 return inverse_transform
- 
+ 
 - def decode(self, x: torch.Tensor, s: torch.Tensor = torch.zeros(1, 1, 0)) -> torch.Tensor:
 - s_stft_real, s_stft_imag = self._stft(s.squeeze(1))
 - s_stft = torch.cat([s_stft_real, s_stft_imag], dim=1)
 + def decode(self, x: torch.Tensor, s_stft: torch.Tensor, index: torch.int) -> torch.Tensor:
- 
+ 
 x = self.conv_pre(x)
 for i in range(self.num_upsamples):
 @@ -356,7 +352,7 @@ class HiFTGenerator(nn.Module):
 x = self.ups[i](x)
- 
+ 
 if i == self.num_upsamples - 1:
 - x = self.reflection_pad(x)
 + x = torch.cat((x, x[:,:,-2:-1]), -1)
- 
+ 
 # fusion
 si = self.source_downs[i](s_stft)
 @@ -373,12 +369,10 @@ class HiFTGenerator(nn.Module):
- 
+ 
 x = F.leaky_relu(x)
 x = self.conv_post(x)
 - magnitude = torch.exp(x[:, :self.istft_params["n_fft"] // 2 + 1, :])
 - phase = torch.sin(x[:, self.istft_params["n_fft"] // 2 + 1:, :]) # actually, sin is redundancy
 + magnitude = torch.exp(x[:, :index, :])
 + phase = torch.sin(x[:, index:, :]) # actually, sin is redundancy
- 
+ 
 - x = self._istft(magnitude, phase)
 - x = torch.clamp(x, -self.audio_limit, self.audio_limit)
 - return x
 + return magnitude, phase
- 
+ 
 def forward(
 self,
 @@ -407,5 +401,12 @@ class HiFTGenerator(nn.Module):
@@ -416,7 +427,7 @@ index bbd3305..7eb32ad 100644
 @@ -229,16 +229,17 @@ class Qwen2Encoder(torch.nn.Module):
 super().__init__()
 self.model = Qwen2ForCausalLM.from_pretrained(pretrain_path)
- 
+ 
 - def forward_one_step(self, xs, masks, cache=None):
 - input_masks = masks[:, -1, :]
 - outs = self.model(
@@ -444,7 +455,7 @@
 @@ -283,6 +284,15 @@ class Qwen2LM(TransformerLM):
 self.sampling = sampling
 self.mix_ratio = mix_ratio
- 
+ 
 + # 5. added to support streaming input
 + self.prompt_speech_token_emb_dict = {}
 + self.lm_input_dict = {}
@@ -481,7 +492,7 @@
 out_tokens.append(top_ids)
 - lm_input = self.speech_embedding.weight[top_ids].reshape(1, 1, -1)
 + lm_input = self.speech_embedding.weight[top_ids].reshape(1, 1, -1).detach().clone()
- 
+ 
 @torch.inference_mode()
 def inference_bistream(
 @@ -432,3 +449,144 @@ class Qwen2LM(TransformerLM):
@@ -635,7 +646,7 @@ index 3e61a8c..d316b92 100644
 --- a/cosyvoice/utils/common.py
 +++ b/cosyvoice/utils/common.py
 @@ -107,12 +107,33 @@ def init_weights(m, mean=0.0, std=0.01):
- 
+ 
 # Repetition Aware Sampling in VALL-E 2
 def ras_sampling(weighted_scores, decoded_tokens, sampling, top_p=0.8, top_k=25, win_size=10, tau_r=0.1):
 - top_ids = nucleus_sampling(weighted_scores, top_p=top_p, top_k=top_k)
@@ -644,7 +655,7 @@
 if rep_num >= win_size * tau_r:
 top_ids = random_sampling(weighted_scores, decoded_tokens, sampling)
 return top_ids
- 
+ 
 +def dst_sampling(weighted_scores, top_p=0.8, top_k=25):
 +
 + sorted_value, sorted_idx = weighted_scores.softmax(dim=0).sort(descending=True, stable=True)
@@ -666,6 +677,6 @@
 + top_ids = selected_indices[selected_prob.multinomial(1, replacement=True)]
 +
 + return top_ids
- 
+ 
 def nucleus_sampling(weighted_scores, top_p=0.8, top_k=25):
 prob, indices = [], []
diff --git a/ACL_PyTorch/built-in/audio/CosyVoice2/README.md b/ACL_PyTorch/built-in/audio/CosyVoice2/README.md
index 833da1e741a1ff1760ceabf9c7ad84117537f8f9..804dadcfe5496ad83eebf350629555e017d4b281 100644
--- a/ACL_PyTorch/built-in/audio/CosyVoice2/README.md
+++ b/ACL_PyTorch/built-in/audio/CosyVoice2/README.md
@@ -43,7 +43,7 @@ ## Get this repository's source code
 ```
 git clone https://gitee.com/ascend/ModelZoo-PyTorch.git
-cd ModelZoo-PyTorch/ACL_PyTorch/built-in/audio/CosyVoice/CosyVoice2
+cd ModelZoo-PyTorch/ACL_PyTorch/built-in/audio/CosyVoice2
 ```
 
 ## Get the source code
@@ -55,6 +55,7 @@ cd ModelZoo-PyTorch/ACL_PyTorch/built-in/audio/CosyVoice/CosyVoice2
 cd CosyVoice
 git reset --hard fd45708
 git submodule update --init --recursive
+ # Apply the patch for the machine type in use. A 313T 800T A2 machine shares the patch file with 800I.
 git apply ../${platform}/diff_CosyVoice_${platform}.patch
 # Copy infer.py into CosyVoice
 cp ../infer.py ./
@@ -63,24 +64,23 @@ cd ModelZoo-PyTorch/ACL_PyTorch/built-in/audio/CosyVoice/CosyVoice2
 cd transformers
 git checkout v4.37.0
 cd ..
- # Replace the modeling_qwen model file in the transformers repo
- mv ../${platform}/modeling_qwen2.py ./transformers/src/transformers/models/qwen2
+ # Replace the modeling_qwen model file in the transformers repo. 800T A2 and 800I A2 share the same modeling_qwen2.py.
+ cp ../${platform}/modeling_qwen2.py ./transformers/src/transformers/models/qwen2
 ```
 The file directory structure is roughly as follows:
 ```text
- 📁 CosyVoice/
- ├── 📁 CosyVoice2/
- | |── 📁 300I
- | |── 📄 diff_CosyVoice_300I.patch
- | |── 📄 modeling_qwen2.py
- | |── 📁 800I
- | |── 📄 diff_CosyVoice_800I.patch
- | |── 📄 modeling_qwen2.py
- | |── 📁 CosyVoice
- | |── 📁 CosyVoice source files # CosyVoice source tree, not listed in full here
- │ ├── 📁 CosyVoice-0.5B/ # weight files
- │ ├── 📁 transformers/ # transformers library containing the modified modeling_qwen2.py
+ 📁 CosyVoice2/
+ |── 📁 300I
+ |── 📄 diff_CosyVoice_300I.patch
+ |── 📄 modeling_qwen2.py
+ |── 📁 800I
+ |── 📄 diff_CosyVoice_800I.patch
+ |── 📄 modeling_qwen2.py
+ |── 📁 CosyVoice
+ |── 📁 CosyVoice source files # CosyVoice source tree, not listed in full here
+ ├── 📁 CosyVoice-0.5B/ # weight files
+ ├── 📁 transformers/ # transformers library containing the modified modeling_qwen2.py
 │── 📄 requirements.txt # dependencies
 |── 📄 infer.py # inference script
 └── 📄 modify_onnx.py # model conversion script
@@ -91,7 +91,7 @@ cd ModelZoo-PyTorch/ACL_PyTorch/built-in/audio/CosyVoice/CosyVoice2
 pip3 install -r ../requirements.txt
 apt-get install sox # on CentOS: yum install sox
 ```
- Note: if WeTextProcessing cannot be installed, it can be built and installed manually as follows
+ Note: if WeTextProcessing cannot be installed, for example with a pynini build error, it can be built and installed manually as follows
 ```bash
 # Download and extract the package
 wget https://www.openfst.org/twiki/pub/FST/FstDownload/openfst-1.8.3.tar.gz
@@ -110,7 +110,8 @@ cd ModelZoo-PyTorch/ACL_PyTorch/built-in/audio/CosyVoice/CosyVoice2
 3. Install the msit tool
 
-   Install the benchmark and surgeon components of the tool as described in [msit](https://gitee.com/ascend/msit). (If they are not installed, an ais_bench import failure is reported.)
+   Install the benchmark and surgeon components of the tool as described in [msit](https://gitee.com/ascend/msit/blob/master/msit/docs/install/README.md). (If they are not installed, an ais_bench import failure is reported.)
+   Installing the msit components from a git clone of the source is recommended; otherwise inference is prone to the error "The stream is not in the current context."
 
 4. Obtain the weights
@@ -200,3 +201,27 @@ cd ModelZoo-PyTorch/ACL_PyTorch/built-in/audio/CosyVoice/CosyVoice2
 | cosyvoice |800I A2|0.28|
 | cosyvoice |300I DUO|0.75|
 
+
+# FAQ
+  1. Environment and dependency installation
+
+  (1) pynini fails to compile when installing the Python libraries in requirements.txt:
+  pynini is a build dependency of WeTextProcessing; when its compilation fails, build and install WeTextProcessing manually as described in step 2 of the "Get the source code" section.
+
+  (2) If the tokenizers library is reported as missing or version-conflicting, tokenizers 0.15.1 can be used.
+
+  2. If model inference on an openEuler system fails with fatal error: 'cstdint' file not found:
+
+  Make sure gcc and g++ are installed successfully.
+  Export the following environment variable:
+  export CPLUS_INCLUDE_PATH=/usr/include/c++/12:/usr/include/c++/12/aarch64-openEuler-linux:$CPLUS_INCLUDE_PATH
+
+  3. Make sure the ATC conversion that generates the OM files and the inference run use the same CANN version.
+
+  4. ATC conversion reports "Soc version is invalid":
+
+  The --soc_version option of the atc command needs the Ascend prefix, e.g. --soc_version=Ascend310P3; check the exact model with npu-smi info.
+
+  5. Running modify_onnx.py reports ModuleNotFoundError: No module named 'auto_optimizer':
+
+  The [msit](https://gitee.com/ascend/msit) tool must be installed first.
diff --git a/ACL_PyTorch/built-in/audio/CosyVoice2/requirements.txt b/ACL_PyTorch/built-in/audio/CosyVoice2/requirements.txt
index fb8c778f6249f1649090f99c6fce4565371e8a61..eda3d5f19e155eb85eec8ec0cf65e25c34b9c8cf 100644
--- a/ACL_PyTorch/built-in/audio/CosyVoice2/requirements.txt
+++ b/ACL_PyTorch/built-in/audio/CosyVoice2/requirements.txt
@@ -26,7 +26,7 @@ soundfile==0.12.1
 tensorboard==2.14.0
 torch==2.3.1
 torch_npu==2.3.1.post6
-torchaudio==2.4.0
+torchaudio==2.3.1
 uvicorn==0.30.0
 wget==3.2
 fastapi==0.111.0