diff --git a/MindIE/MindIE-Torch/built-in/foundation/stable-audio-open-1.0/diffusers/README.md b/MindIE/MindIE-Torch/built-in/foundation/stable-audio-open-1.0/diffusers/README.md
index 61a1864ca2c60856b12ad9f88b761fce72fcdb88..2ddb2adbf744ff20c0c079d3ac8deb5d14707d25 100644
--- a/MindIE/MindIE-Torch/built-in/foundation/stable-audio-open-1.0/diffusers/README.md
+++ b/MindIE/MindIE-Torch/built-in/foundation/stable-audio-open-1.0/diffusers/README.md
@@ -1,4 +1,4 @@
-# stable-audio-open-1.0 Model Inference Guide
+# stable-audio-open-1.0 Model Inference Guide (diffusers)
- [Overview](#ZH-CN_TOPIC_0000001172161501)
@@ -164,8 +164,8 @@
--audio_end_in_s 10 10 47 \
--num_waveforms_per_prompt 1 \
--guidance_scale 7 \
- --device 0 \
- --save_dir ./result
+ --save_dir ./results \
+ --device 0
```
Parameter description:
@@ -173,11 +173,11 @@
- --output_dir: directory where the exported models are stored.
- --prompt_file: prompt file.
- --num_inference_steps: number of denoising iterations for audio generation.
- - --save_dir: directory where the generated audio is stored.
- - --device: inference device ID.
- --audio_end_in_s: duration of the generated audio in seconds; defaults to 10s if not set.
- --num_waveforms_per_prompt: number of audio clips generated per prompt.
- --guidance_scale: coefficient balancing audio generation quality against prompt accuracy.
+ - --save_dir: directory where the generated audio is stored.
+ - --device: inference device ID.
After execution completes, the generated audio files are placed in the `./results` directory, in the same order as the prompts in the text file, and the inference time is printed to the terminal.
diff --git a/MindIE/MindIE-Torch/built-in/foundation/stable-audio-open-1.0/diffusers/precision_brownian_interval.patch b/MindIE/MindIE-Torch/built-in/foundation/stable-audio-open-1.0/diffusers/precision_brownian_interval.patch
index fcaca7605d897b92393d71c6cb3612a742dfdc82..d9d94e58016f1ae0d0ca7a347759abe2d25907e7 100644
--- a/MindIE/MindIE-Torch/built-in/foundation/stable-audio-open-1.0/diffusers/precision_brownian_interval.patch
+++ b/MindIE/MindIE-Torch/built-in/foundation/stable-audio-open-1.0/diffusers/precision_brownian_interval.patch
@@ -8,3 +8,4 @@
- return torch.randn(size, dtype=dtype, device=device, generator=generator)
+ torch.manual_seed(int(seed))
+ return torch.randn(size, dtype=dtype, device="cpu").to(device)
+
diff --git a/MindIE/MindIE-Torch/built-in/foundation/stable-audio-open-1.0/diffusers/prompts.txt b/MindIE/MindIE-Torch/built-in/foundation/stable-audio-open-1.0/diffusers/prompts.txt
new file mode 100644
index 0000000000000000000000000000000000000000..e1c7734ef9c418f15b6c67c338c81f2cb39b1e7e
--- /dev/null
+++ b/MindIE/MindIE-Torch/built-in/foundation/stable-audio-open-1.0/diffusers/prompts.txt
@@ -0,0 +1,3 @@
+Berlin techno, rave, drum machine, kick, ARP synthesizer, dark, moody, hypnotic, evolving, 135BPM. LOOP.
+Uplifting acoustic loop. 120 BPM.
+Disco, Driving Drum Machine, Synthesizer, Bass, Piano, Guitars, Instrumental, Clubby, Euphoric, Chicago, New York, 115 BPM.
\ No newline at end of file
diff --git a/MindIE/MindIE-Torch/built-in/foundation/stable-audio-open-1.0/stable-audio-tools/README.md b/MindIE/MindIE-Torch/built-in/foundation/stable-audio-open-1.0/stable-audio-tools/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..99af1a0d3c6d6a5e38b35077e90585bae161a744
--- /dev/null
+++ b/MindIE/MindIE-Torch/built-in/foundation/stable-audio-open-1.0/stable-audio-tools/README.md
@@ -0,0 +1,151 @@
+# stable-audio-open-1.0 Model Inference Guide (stable-audio-tools)
+
+- [Overview](#ZH-CN_TOPIC_0000001172161501)
+
+
+- [Inference Environment Setup](#ZH-CN_TOPIC_0000001126281702)
+
+- [Quick Start](#ZH-CN_TOPIC_0000001126281700)
+
+  - [Obtaining the Source Code](#section4622531142816)
+  - [Model Inference](#section741711594517)
+
+- [Inference Performance & Accuracy](#ZH-CN_TOPIC_0000001172201573)
+
+# Overview
+
+The model weights can be obtained [here](https://huggingface.co/stabilityai/stable-audio-open-1.0).
+
+- Reference implementation:
+ ```bash
+ # StableAudioOpen1.0
+ https://huggingface.co/stabilityai/stable-audio-open-1.0
+ ```
+
+# Inference Environment Setup
+
+- This model requires the following plugins and drivers.
+
+  **Table 1** Version compatibility
+
+  | Dependency | Version | Setup Guide |
+  | ----- | ----- | ----- |
+  | Python | 3.10.2 | - |
+  | torch | 2.1.0 | - |
+
+The performance of this model depends on the CPU specification; a 64-core (ARM) CPU is recommended to reproduce the reported performance.
+
+# Quick Start
+## Obtaining the Source Code
+1. Install dependencies.
+ ```bash
+ pip3 install -r requirements.txt
+ ```
+
+2. Install the MindIE package.
+
+ ```bash
+   # install MindIE
+ source /usr/local/Ascend/ascend-toolkit/set_env.sh
+ chmod +x ./Ascend-mindie_xxx.run
+ ./Ascend-mindie_xxx.run --install
+ source /usr/local/Ascend/mindie/set_env.sh
+ ```
+
+3. Apply the code patches.
+
+   Run the following commands:
+ ```bash
+ python3 conditioners_patch.py
+ python3 pretrained_patch.py
+ ```
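+
+   To confirm that both patches applied cleanly, a quick optional check is to locate the installed package and inspect the patched files (this sketch assumes `stable_audio_tools` was installed via pip):
+   ```bash
+   python3 -c "import stable_audio_tools, os; print(os.path.dirname(stable_audio_tools.__file__))"
+   # then open <package_path>/models/conditioners.py and <package_path>/models/pretrained.py
+   # and verify the patched code is present
+   ```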
+
+## Model Inference
+
+1. Prepare the model.
+   1. Obtain the model weights.
+
+      The weights can be downloaded in advance to avoid download failures during later steps.
+
+ ```bash
+      # git-lfs is required (https://git-lfs.com)
+ git lfs install
+
+      # download the stable-audio-open-1.0 weights
+ git clone https://huggingface.co/stabilityai/stable-audio-open-1.0
+ ```
+
+   2. Set the path to the model weights.
+ ```bash
+      # stable-audio-open-1.0 (download the weights at run time)
+ model_base="stabilityai/stable-audio-open-1.0"
+
+      # stable-audio-open-1.0 (use the weights downloaded in the previous step)
+ model_base="./stable-audio-open-1.0"
+ ```
+
+   3. Obtain the T5 model weights (optional).
+
+      During inference the T5-base weights are downloaded automatically from huggingface. To run inference with locally stored T5-base weights instead, copy the `tokenizer` and `text_encoder` folders from the `model_base` path into the directory where the inference code is executed, as sketched below.
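+
+      A minimal sketch of this copy step, assuming `model_base` points at the checkout downloaded above:
+      ```bash
+      cp -r ${model_base}/tokenizer ./tokenizer
+      cp -r ${model_base}/text_encoder ./text_encoder
+      ```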
+
+
+2. Start inference validation.
+
+   1. Enable CPU high-performance mode.
+ ```bash
+      echo performance | tee /sys/devices/system/cpu/cpu*/cpufreq/scaling_governor
+ sysctl -w vm.swappiness=0
+ sysctl -w kernel.numa_balancing=0
+ ```
+
+   2. Install the core-binding tool.
+ ```bash
+ apt-get update
+ apt-get install numactl
+ ```
+      Query the NUMA node of the card:
+ ```shell
+ lspci -vs bus-id
+ ```
+      The bus-id can be obtained from npu-smi info. Once the NUMA node is known, prefix the inference command with the corresponding core numbers.
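+
+      For example (the bus-id below is illustrative; substitute the one reported for your card):
+      ```bash
+      npu-smi info                     # read the PCIe bus-id of the target NPU
+      lspci -vs c1:00.0 | grep -i numa # shows the NUMA node of that bus-id
+      ```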
+
+      The CPU cores belonging to each NUMA node can be listed with lscpu:
+ ```shell
+ NUMA node0: 0-23
+ NUMA node1: 24-47
+ NUMA node2: 48-71
+ NUMA node3: 72-95
+ ```
+      Here the queried NUMA node is 0, which corresponds to cores 0-23; binding to a single core in this range is recommended for better performance.
+
+   3. Run the inference script.
+ ```bash
+ numactl -C 0-23 python3 stable_audio_open_tools_pipeline.py \
+ --model ${model_base} \
+ --prompt_file ./prompts.txt \
+ --num_inference_steps 100 \
+ --seconds_total 10 10 47 \
+ --save_dir ./results \
+ --device 0
+ ```
+
+   Parameter description:
+   - --model: path to the model weights.
+   - --prompt_file: prompt file.
+   - --num_inference_steps: number of denoising iterations for audio generation.
+   - --seconds_total: duration of each generated clip in seconds, matched one-to-one with the prompts in prompts.txt; defaults to 10s if not set.
+   - --save_dir: directory where the generated audio is stored.
+   - --device: inference device ID.
+
+   After execution completes, the generated audio files are placed in the `./results` directory, in the same order as the prompts in the text file, and the inference time is printed to the terminal.
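+
+   The output file names follow the pipeline script, so for the three sample prompts the directory should look roughly like this:
+   ```bash
+   ls ./results
+   # audio_by_prompt1.wav  audio_by_prompt2.wav  audio_by_prompt3.wav
+   ```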
+
+
+
+# Inference Performance & Accuracy
+Refer to the following data for performance.
+
+### Stable-Audio-Open-1.0
+
+| Hardware | Inference steps | Average latency |
+| :------: | :----: | :----: |
+| A2 | 100 | 14.711s |
\ No newline at end of file
diff --git a/MindIE/MindIE-Torch/built-in/foundation/stable-audio-open-1.0/stable-audio-tools/conditioners.patch b/MindIE/MindIE-Torch/built-in/foundation/stable-audio-open-1.0/stable-audio-tools/conditioners.patch
new file mode 100644
index 0000000000000000000000000000000000000000..c61a74932a4584ef1cb4dabe31f8c370811f1bf1
--- /dev/null
+++ b/MindIE/MindIE-Torch/built-in/foundation/stable-audio-open-1.0/stable-audio-tools/conditioners.patch
@@ -0,0 +1,24 @@
+--- conditioners.py 2024-09-30 15:31:32.480360700 +0800
++++ conditioners_patch.py 2024-09-30 18:20:43.344830200 +0800
+@@ -280,10 +280,17 @@
+ with warnings.catch_warnings():
+ warnings.simplefilter("ignore")
+ try:
+- # self.tokenizer = T5Tokenizer.from_pretrained(t5_model_name, model_max_length = max_length)
+- # model = T5EncoderModel.from_pretrained(t5_model_name, max_length=max_length).train(enable_grad).requires_grad_(enable_grad)
+- self.tokenizer = AutoTokenizer.from_pretrained(t5_model_name)
+- model = T5EncoderModel.from_pretrained(t5_model_name).train(enable_grad).requires_grad_(enable_grad).to(torch.float16)
++ import os
++ tokenizer_path = os.path.join(os.getcwd(), "tokenizer")
++ text_encoder_path = os.path.join(os.getcwd(), "text_encoder")
++ if os.path.exists(tokenizer_path) and os.path.exists(text_encoder_path):
++                print("Loading T5-base from local files ...")
++ self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
++ model = T5EncoderModel.from_pretrained(text_encoder_path).train(enable_grad).requires_grad_(enable_grad).to(torch.float16)
++ else:
++                print("Downloading T5-base from HuggingFace ...")
++ self.tokenizer = AutoTokenizer.from_pretrained(t5_model_name)
++ model = T5EncoderModel.from_pretrained(t5_model_name).train(enable_grad).requires_grad_(enable_grad).to(torch.float16)
+ finally:
+ logging.disable(previous_level)
+
diff --git a/MindIE/MindIE-Torch/built-in/foundation/stable-audio-open-1.0/stable-audio-tools/conditioners_patch.py b/MindIE/MindIE-Torch/built-in/foundation/stable-audio-open-1.0/stable-audio-tools/conditioners_patch.py
new file mode 100644
index 0000000000000000000000000000000000000000..71db741779f701c76d73093bb735523fe01200d1
--- /dev/null
+++ b/MindIE/MindIE-Torch/built-in/foundation/stable-audio-open-1.0/stable-audio-tools/conditioners_patch.py
@@ -0,0 +1,12 @@
+import os
+import stable_audio_tools
+
+
+def main():
+ stable_audio_tools_path = stable_audio_tools.__path__
+
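+    # patch the installed stable_audio_tools package in place so that T5-base can be
+    # loaded from local tokenizer/ and text_encoder/ directories when they exist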
+ os.system(f'patch -p0 {stable_audio_tools_path[0]}/models/conditioners.py conditioners.patch')
+
+
+if __name__ == '__main__':
+ main()
\ No newline at end of file
diff --git a/MindIE/MindIE-Torch/built-in/foundation/stable-audio-open-1.0/stable-audio-tools/pretrained.patch b/MindIE/MindIE-Torch/built-in/foundation/stable-audio-open-1.0/stable-audio-tools/pretrained.patch
new file mode 100644
index 0000000000000000000000000000000000000000..f51e6a1d90f1ff875f2fee8a1fde06a21b7f1eb3
--- /dev/null
+++ b/MindIE/MindIE-Torch/built-in/foundation/stable-audio-open-1.0/stable-audio-tools/pretrained.patch
@@ -0,0 +1,29 @@
+--- pretrained.py 2024-09-30 15:31:40.672485200 +0800
++++ pretrained_patch.py 2024-10-07 14:54:18.756960100 +0800
+@@ -1,4 +1,5 @@
+ import json
++import os
+
+ from .factory import create_model_from_config
+ from .utils import load_ckpt_state_dict
+@@ -7,7 +8,7 @@
+
+ def get_pretrained_model(name: str):
+
+- model_config_path = hf_hub_download(name, filename="model_config.json", repo_type='model')
++ model_config_path = os.path.join(name, "model_config.json")
+
+ with open(model_config_path) as f:
+ model_config = json.load(f)
+@@ -15,10 +16,7 @@
+ model = create_model_from_config(model_config)
+
+ # Try to download the model.safetensors file first, if it doesn't exist, download the model.ckpt file
+- try:
+- model_ckpt_path = hf_hub_download(name, filename="model.safetensors", repo_type='model')
+- except Exception as e:
+- model_ckpt_path = hf_hub_download(name, filename="model.ckpt", repo_type='model')
++ model_ckpt_path = os.path.join(name, "model.safetensors")
+
+ model.load_state_dict(load_ckpt_state_dict(model_ckpt_path))
+
diff --git a/MindIE/MindIE-Torch/built-in/foundation/stable-audio-open-1.0/stable-audio-tools/pretrained_patch.py b/MindIE/MindIE-Torch/built-in/foundation/stable-audio-open-1.0/stable-audio-tools/pretrained_patch.py
new file mode 100644
index 0000000000000000000000000000000000000000..4abdad47b0b49a4263f674bc5d9a17768602ac66
--- /dev/null
+++ b/MindIE/MindIE-Torch/built-in/foundation/stable-audio-open-1.0/stable-audio-tools/pretrained_patch.py
@@ -0,0 +1,12 @@
+import os
+import stable_audio_tools
+
+
+def main():
+ stable_audio_tools_path = stable_audio_tools.__path__
+
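+    # patch the installed stable_audio_tools package in place so that the model config
+    # and weights are read from a local directory instead of the HuggingFace Hub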
+ os.system(f'patch -p0 {stable_audio_tools_path[0]}/models/pretrained.py pretrained.patch')
+
+
+if __name__ == '__main__':
+ main()
\ No newline at end of file
diff --git a/MindIE/MindIE-Torch/built-in/foundation/stable-audio-open-1.0/stable-audio-tools/prompts.txt b/MindIE/MindIE-Torch/built-in/foundation/stable-audio-open-1.0/stable-audio-tools/prompts.txt
new file mode 100644
index 0000000000000000000000000000000000000000..e1c7734ef9c418f15b6c67c338c81f2cb39b1e7e
--- /dev/null
+++ b/MindIE/MindIE-Torch/built-in/foundation/stable-audio-open-1.0/stable-audio-tools/prompts.txt
@@ -0,0 +1,3 @@
+Berlin techno, rave, drum machine, kick, ARP synthesizer, dark, moody, hypnotic, evolving, 135BPM. LOOP.
+Uplifting acoustic loop. 120 BPM.
+Disco, Driving Drum Machine, Synthesizer, Bass, Piano, Guitars, Instrumental, Clubby, Euphoric, Chicago, New York, 115 BPM.
\ No newline at end of file
diff --git a/MindIE/MindIE-Torch/built-in/foundation/stable-audio-open-1.0/stable-audio-tools/requirements.txt b/MindIE/MindIE-Torch/built-in/foundation/stable-audio-open-1.0/stable-audio-tools/requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..df8e40939ae837bb785620795a23d154ba3fa37d
--- /dev/null
+++ b/MindIE/MindIE-Torch/built-in/foundation/stable-audio-open-1.0/stable-audio-tools/requirements.txt
@@ -0,0 +1,5 @@
+torch==2.1.0
+torchaudio==2.1.0
+stable_audio_tools==0.0.16
+transformers==4.40.0
+torch_npu==2.1.0.post6
\ No newline at end of file
diff --git a/MindIE/MindIE-Torch/built-in/foundation/stable-audio-open-1.0/stable-audio-tools/stable_audio_open_tools_pipeline.py b/MindIE/MindIE-Torch/built-in/foundation/stable-audio-open-1.0/stable-audio-tools/stable_audio_open_tools_pipeline.py
new file mode 100644
index 0000000000000000000000000000000000000000..192879312f02842f0aed6da4e3cc512e5d509fe6
--- /dev/null
+++ b/MindIE/MindIE-Torch/built-in/foundation/stable-audio-open-1.0/stable-audio-tools/stable_audio_open_tools_pipeline.py
@@ -0,0 +1,121 @@
+import argparse
+import os
+import time
+
+import torch
+import torch_npu
+import torchaudio
+from einops import rearrange
+from stable_audio_tools import get_pretrained_model
+from stable_audio_tools.inference.generation import generate_diffusion_cond
+
+def parse_arguments():
+ parser = argparse.ArgumentParser()
+ parser.add_argument(
+ "--prompt_file",
+ type=str,
+ default="./prompts.txt",
+ help="The prompts file to guide audio generation.",
+ )
+ parser.add_argument(
+ "--num_inference_steps",
+ type=int,
+ default=100,
+ help="The number of denoising steps. More denoising steps usually lead to a higher quality audio at the expense of slower inference.",
+ )
+ parser.add_argument(
+ "--model",
+ type=str,
+ default="./stable-audio-open-1.0",
+ help="The path of stable-audio-open-1.0.",
+ )
+ parser.add_argument(
+ "--seconds_total",
+ nargs='+',
+ default=[10],
+        help="Target duration in seconds for each generated clip, one value per prompt.",
+ )
+ parser.add_argument(
+ "--device",
+ type=int,
+ default=0,
+ help="NPU device id.",
+ )
+ parser.add_argument(
+ "--save_dir",
+ type=str,
+ default="./results",
+ help="Path to save result audio files.",
+ )
+ return parser.parse_args()
+
+def main():
+ args = parse_arguments()
+ save_dir = args.save_dir
+ if not os.path.exists(save_dir):
+ os.makedirs(save_dir)
+
+ torch_npu.npu.set_device(args.device)
+ npu_stream = torch_npu.npu.Stream()
+
+ model, model_config = get_pretrained_model(args.model)
+ sample_rate = model_config["sample_rate"]
+ sample_size = model_config["sample_size"]
+
+ model = model.to("npu").to(torch.float16).eval()
+
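+    # a single conditioning entry is reused across prompts; the prompt text and
+    # target duration are filled in per line of the prompt file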
+    conditioning = [{
+        "prompt": "",
+        "seconds_start": 0,
+        "seconds_total": 0,
+    }]
+ total_time = 0
+ prompts_num = 0
+ average_time = 0
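+    # the first two prompts (skip = 2) are warm-up runs excluded from the timing average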
+ skip = 2
+ with os.fdopen(os.open(args.prompt_file, os.O_RDONLY), "r") as f:
+ for i, prompt in enumerate(f):
+ with torch.no_grad():
+ conditioning[0]["prompt"] = prompt
+ conditioning[0]["seconds_total"] = float(args.seconds_total[i]) if (len(args.seconds_total) > i) else 10.0
+
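+            # synchronize the NPU stream so the timer measures the complete generation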
+ npu_stream.synchronize()
+ begin = time.time()
+ output = generate_diffusion_cond(
+ model,
+ steps=args.num_inference_steps,
+ cfg_scale=7,
+ conditioning=conditioning,
+ sample_size=sample_size,
+ sigma_min=0.3,
+ sigma_max=500,
+ sampler_type="dpmpp-3m-sde",
+ device="npu"
+ )
+ npu_stream.synchronize()
+ end = time.time()
+            if i >= skip:
+                total_time += end - begin
+            prompts_num = i + 1
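+            # trim the generated waveform to the requested duration, then normalize to int16 PCM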
+ waveform_start = int(conditioning[0]["seconds_start"] * sample_rate)
+ waveform_end = int(conditioning[0]["seconds_total"] * sample_rate)
+ output = output[:, :, waveform_start:waveform_end]
+ output = rearrange(output, "b d n -> d (b n)")
+            output = output.to(torch.float32).div(torch.max(torch.abs(output))).clamp(-1, 1).mul(32767).to(torch.int16).cpu()
+            torchaudio.save(os.path.join(args.save_dir, f"audio_by_prompt{prompts_num}.wav"), output, sample_rate)
+    if prompts_num > skip:
+        average_time = total_time / (prompts_num - skip)
+        print(f"Infer average time: {average_time:.3f}s\n")
+    else:
+        print("The reported average time excludes the first two warm-up prompts; make sure prompts.txt contains more than two prompts.")
+
+if __name__ == "__main__":
+ main()