From 21b457f953854d65810f2750250789bacb952a5e Mon Sep 17 00:00:00 2001 From: zhoufan2956 Date: Wed, 15 Jan 2025 16:33:34 +0800 Subject: [PATCH 1/6] add_stable_audio --- .../stable_audio_open/README.md | 152 +++ .../inference_stableaudio.py | 143 +++ .../stable_audio_open/prompts/prompts.txt | 6 + .../stable_audio_open/requirements.txt | 6 + .../stable_audio_open/stableaudio/__init__.py | 4 + .../stableaudio/layers/__init__.py | 4 + .../stableaudio/layers/attention.py | 374 ++++++++ .../stableaudio/layers/linear.py | 99 ++ .../stableaudio/models/__init__.py | 1 + .../models/stable_audio_transformer.py | 428 +++++++++ .../stableaudio/pipeline/__init__.py | 1 + .../pipeline/pipeline_stable_audio.py | 754 +++++++++++++++ .../stableaudio/schedulers/__init__.py | 1 + .../scheduling_cosine_dpmsolver_multistep.py | 572 +++++++++++ .../schedulers/scheduling_dpmsolver_sde.py | 70 ++ .../schedulers/scheduling_utils.py | 888 ++++++++++++++++++ 16 files changed, 3503 insertions(+) create mode 100644 MindIE/MindIE-Torch/built-in/foundation/stable_audio_open_1.0/stable_audio_open/README.md create mode 100644 MindIE/MindIE-Torch/built-in/foundation/stable_audio_open_1.0/stable_audio_open/inference_stableaudio.py create mode 100644 MindIE/MindIE-Torch/built-in/foundation/stable_audio_open_1.0/stable_audio_open/prompts/prompts.txt create mode 100644 MindIE/MindIE-Torch/built-in/foundation/stable_audio_open_1.0/stable_audio_open/requirements.txt create mode 100644 MindIE/MindIE-Torch/built-in/foundation/stable_audio_open_1.0/stable_audio_open/stableaudio/__init__.py create mode 100644 MindIE/MindIE-Torch/built-in/foundation/stable_audio_open_1.0/stable_audio_open/stableaudio/layers/__init__.py create mode 100644 MindIE/MindIE-Torch/built-in/foundation/stable_audio_open_1.0/stable_audio_open/stableaudio/layers/attention.py create mode 100644 MindIE/MindIE-Torch/built-in/foundation/stable_audio_open_1.0/stable_audio_open/stableaudio/layers/linear.py create mode 100644 MindIE/MindIE-Torch/built-in/foundation/stable_audio_open_1.0/stable_audio_open/stableaudio/models/__init__.py create mode 100644 MindIE/MindIE-Torch/built-in/foundation/stable_audio_open_1.0/stable_audio_open/stableaudio/models/stable_audio_transformer.py create mode 100644 MindIE/MindIE-Torch/built-in/foundation/stable_audio_open_1.0/stable_audio_open/stableaudio/pipeline/__init__.py create mode 100644 MindIE/MindIE-Torch/built-in/foundation/stable_audio_open_1.0/stable_audio_open/stableaudio/pipeline/pipeline_stable_audio.py create mode 100644 MindIE/MindIE-Torch/built-in/foundation/stable_audio_open_1.0/stable_audio_open/stableaudio/schedulers/__init__.py create mode 100644 MindIE/MindIE-Torch/built-in/foundation/stable_audio_open_1.0/stable_audio_open/stableaudio/schedulers/scheduling_cosine_dpmsolver_multistep.py create mode 100644 MindIE/MindIE-Torch/built-in/foundation/stable_audio_open_1.0/stable_audio_open/stableaudio/schedulers/scheduling_dpmsolver_sde.py create mode 100644 MindIE/MindIE-Torch/built-in/foundation/stable_audio_open_1.0/stable_audio_open/stableaudio/schedulers/scheduling_utils.py diff --git a/MindIE/MindIE-Torch/built-in/foundation/stable_audio_open_1.0/stable_audio_open/README.md b/MindIE/MindIE-Torch/built-in/foundation/stable_audio_open_1.0/stable_audio_open/README.md new file mode 100644 index 0000000000..389707e71f --- /dev/null +++ b/MindIE/MindIE-Torch/built-in/foundation/stable_audio_open_1.0/stable_audio_open/README.md @@ -0,0 +1,152 @@ +--- +pipeline_tag: text-to-audio +frameworks: + - PyTorch +license: 
apache-2.0 +library_name: openmind +hardwares: + - NPU +language: + - en +--- +## 一、准备运行环境 + + **表 1** 版本配套表 + + | 配套 | 版本 | 环境准备指导 | + | ----- | ----- |-----| + | Python | 3.10.2 | - | + | torch | 2.1.0 | - | + +### 1.1 获取CANN&MindIE安装包&环境准备 +- [800I A2](https://www.hiascend.com/developer/download/community/result?module=pt+ie+cann&product=4&model=32) +- [Duo卡](https://www.hiascend.com/developer/download/community/result?module=pt+ie+cann&product=2&model=17) +- [环境准备指导](https://www.hiascend.com/document/detail/zh/CANNCommunityEdition/80RC2alpha002/softwareinst/instg/instg_0001.html) + +### 1.2 CANN安装 +```shell +# 增加软件包可执行权限,{version}表示软件版本号,{arch}表示CPU架构,{soc}表示昇腾AI处理器的版本。 +chmod +x ./Ascend-cann-toolkit_{version}_linux-{arch}.run +chmod +x ./Ascend-cann-kernels-{soc}_{version}_linux.run +# 校验软件包安装文件的一致性和完整性 +./Ascend-cann-toolkit_{version}_linux-{arch}.run --check +./Ascend-cann-kernels-{soc}_{version}_linux.run --check +# 安装 +./Ascend-cann-toolkit_{version}_linux-{arch}.run --install +./Ascend-cann-kernels-{soc}_{version}_linux.run --install + +# 设置环境变量 +source /usr/local/Ascend/ascend-toolkit/set_env.sh +``` + + +### 1.3 MindIE安装 +```shell +# 增加软件包可执行权限,{version}表示软件版本号,{arch}表示CPU架构。 +chmod +x ./Ascend-mindie_${version}_linux-${arch}.run +./Ascend-mindie_${version}_linux-${arch}.run --check + +# 方式一:默认路径安装 +./Ascend-mindie_${version}_linux-${arch}.run --install +# 设置环境变量 +cd /usr/local/Ascend/mindie && source set_env.sh + +# 方式二:指定路径安装 +./Ascend-mindie_${version}_linux-${arch}.run --install-path=${AieInstallPath} +# 设置环境变量 +cd ${AieInstallPath}/mindie && source set_env.sh +``` + +### 1.4 Torch_npu安装 +安装pytorch框架 版本2.1.0 +[安装包下载](https://download.pytorch.org/whl/cpu/torch/) + +使用pip安装 +```shell +# {version}表示软件版本号,{arch}表示CPU架构。 +pip install torch-${version}-cp310-cp310-linux_${arch}.whl +``` +下载 pytorch_v{pytorchversion}_py{pythonversion}.tar.gz +```shell +tar -xzvf pytorch_v{pytorchversion}_py{pythonversion}.tar.gz +# 解压后,会有whl包 +pip install torch_npu-{pytorchversion}.xxxx.{arch}.whl +``` + +### 1.5 安装mindspeed +```shell +# 下载mindspeed源码仓: +git clone https://gitee.com/ascend/MindSpeed.git +# 执行如下命令进行安装: +pip install -e MindSpeed +``` + +## 二、下载本仓库 + +### 2.1 下载到本地 +```shell +git clone https://modelers.cn/MindIE/stable_audio_open_1.0.git +``` + +### 2.2 依赖安装 +```bash +pip3 install -r requirements.txt +apt-get update +apt-get install libsndfile1 +``` + +## 三、Stable-Audio-Open-1.0 使用 + +### 3.1 权重及配置文件说明 +stable-audio-open-1.0权重链接: +```shell +https://huggingface.co/stabilityai/stable-audio-open-1.0/tree/main +``` + +### 3.2 单卡功能测试 +设置权重路径 +```shell +model_base='./stable-audio-open-1.0' +``` +执行命令: +```shell +# 不使用DiTCache策略 +python3 inference_stableaudio.py \ + --model ${model_base} \ + --prompt_file ./prompts/prompts.txt \ + --num_inference_steps 100 \ + --audio_end_in_s 10 10 47 \ + --save_dir ./results \ + --seed 1 \ + --device 0 + +# 使用DiTCache策略 +python3 inference_stableaudio.py \ + --model ${model_base} \ + --prompt_file ./prompts/prompts.txt \ + --num_inference_steps 100 \ + --audio_end_in_s 10 10 47 \ + --save_dir ./results \ + --seed 1 \ + --device 0 \ + --use_cache +``` +参数说明: +- --model:模型权重路径。 +- --prompt_file:提示词文件。 +- --num_inference_steps: 语音生成迭代次数。 +- --audio_end_in_s:生成语音的时长,如不输入则默认生成10s。 +- --save_dir:生成语音的存放目录。 +- --seed:设置随机种子,不指定时默认使用随机种子。 +- --device:推理设备ID。 +- --use_cache: 【可选】使用DiTCache策略。 + +执行完成后在`./results`目录下生成推理语音,语音生成顺序与文本中prompt顺序保持一致,并在终端显示推理时间。 + +### 3.2 模型推理性能 + +性能参考下列数据。 + +| 硬件形态 | 迭代次数 | 平均耗时(w/o DiTCache)| 平均耗时(with DiTCache)| +| :------: 
|:----:|:----:|:----:| +| Atlas 800I A2 (32G) | 100 | 5.645s | 4.991s | \ No newline at end of file diff --git a/MindIE/MindIE-Torch/built-in/foundation/stable_audio_open_1.0/stable_audio_open/inference_stableaudio.py b/MindIE/MindIE-Torch/built-in/foundation/stable_audio_open_1.0/stable_audio_open/inference_stableaudio.py new file mode 100644 index 0000000000..08801a0082 --- /dev/null +++ b/MindIE/MindIE-Torch/built-in/foundation/stable_audio_open_1.0/stable_audio_open/inference_stableaudio.py @@ -0,0 +1,143 @@ +import torch +import torch_npu +import time +import json +import os +import argparse +import soundfile as sf +from safetensors.torch import load_file + +from transformers import T5TokenizerFast, T5EncoderModel +from stableaudio import ( + StableAudioPipeline, + StableAudioDiTModel, + StableAudioProjectionModel, + CosineDPMSolverMultistepScheduler, + AutoencoderOobleck, +) + + +def parse_arguments(): + parser = argparse.ArgumentParser() + parser.add_argument( + "--prompt_file", + type=str, + default="./prompts/prompts.txt", + help="The prompts file to guide audio generation.", + ) + parser.add_argument( + "--negative_prompt", + type=str, + default="", + help="The prompt or prompts to guide what to not include in audio generation.", + ) + parser.add_argument( + "--num_inference_steps", + type=int, + default=100, + help="The number of denoising steps. More denoising steps usually lead to a higher quality audio at the expense of slower inference.", + ) + parser.add_argument( + "--model", + type=str, + default="./stable-audio-open-1.0", + help="The path of stable-audio-open-1.0.", + ) + parser.add_argument( + "--audio_end_in_s", + nargs='+', + default=[10], + help="Audio end index in seconds.", + ) + parser.add_argument( + "--device", + type=int, + default=0, + help="NPU device id.", + ) + parser.add_argument( + "--save_dir", + type=str, + default="./results", + help="Path to save result audio files.", + ) + parser.add_argument( + "--seed", + type=int, + default=-1, + help="Random seed, default 1.", + ) + parser.add_argument( + "--use_cache", + action="store_true", + help="turn on dit cache or not.", + ) + return parser.parse_args() + +def main(): + args = parse_arguments() + save_dir = args.save_dir + if not os.path.exists(save_dir): + os.makedirs(save_dir) + + torch_npu.npu.set_device(args.device) + if args.seed != -1: + torch.manual_seed(args.seed) + latents = torch.randn(1, 64, 1024, dtype=torch.float16,device="cpu") + with open(args.model + "/vae/config.json", "r", encoding="utf-8") as reader: + data = reader.read() + json_data = json.loads(data) + init_dict = {key: json_data[key] for key in json_data} + vae = AutoencoderOobleck(**init_dict) + vae.load_state_dict(load_file(args.model + "/vae/diffusion_pytorch_model.safetensors"), strict=False) + + tokenizer = T5TokenizerFast.from_pretrained(args.model + "/tokenizer") + text_encoder = T5EncoderModel.from_pretrained(args.model + "/text_encoder") + projection_model = StableAudioProjectionModel.from_pretrained(args.model + "/projection_model") + audio_dit = StableAudioDiTModel.from_pretrained(args.model + "/transformer") + scheduler = CosineDPMSolverMultistepScheduler.from_pretrained(args.model + "/scheduler") + + npu_stream = torch_npu.npu.Stream() + vae = vae.to("npu").to(torch.float16).eval() + text_encoder = text_encoder.to("npu").to(torch.float16).eval() + projection_model = projection_model.to("npu").to(torch.float16).eval() + audio_dit = audio_dit.to("npu").to(torch.float16).eval() + + pipe = StableAudioPipeline(vae=vae, 
tokenizer=tokenizer, text_encoder=text_encoder, + projection_model=projection_model, transformer=audio_dit, scheduler=scheduler) + pipe.to("npu") + + total_time = 0 + prompts_num = 0 + average_time = 0 + skip = 2 + with os.fdopen(os.open(args.prompt_file, os.O_RDONLY), "r") as f: + for i, prompt in enumerate(f): + with torch.no_grad(): + npu_stream.synchronize() + audio_end_in_s = float(args.audio_end_in_s[i]) if (len(args.audio_end_in_s) > i) else 10.0 + begin = time.time() + audio = pipe( + prompt=prompt, + negative_prompt=args.negative_prompt, + num_inference_steps=args.num_inference_steps, + latents=latents.to("npu"), + audio_end_in_s=audio_end_in_s, + use_cache=args.use_cache, + ).audios + npu_stream.synchronize() + end = time.time() + if i > skip - 1: + total_time += end - begin + prompts_num = i+1 + output = audio[0].T.float().cpu().numpy() + sf.write(args.save_dir + "/audio_by_prompt" + str(prompts_num) + ".wav", output, pipe.vae.sampling_rate) + if prompts_num > skip: + average_time = total_time / (prompts_num-skip) + else: + raise ValueError("Infer average time skip first two prompts, ensure that prompts.txt \ + contains more than three prompts") + print(f"Infer average time: {average_time:.3f}s\n") + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/MindIE/MindIE-Torch/built-in/foundation/stable_audio_open_1.0/stable_audio_open/prompts/prompts.txt b/MindIE/MindIE-Torch/built-in/foundation/stable_audio_open_1.0/stable_audio_open/prompts/prompts.txt new file mode 100644 index 0000000000..e977cc3ad7 --- /dev/null +++ b/MindIE/MindIE-Torch/built-in/foundation/stable_audio_open_1.0/stable_audio_open/prompts/prompts.txt @@ -0,0 +1,6 @@ +Berlin techno, rave, drum machine, kick, ARP synthesizer, dark, moody, hypnotic, evolving, 135BPM. LOOP. +Uplifting acoustic loop. 120 BPM. +Disco, Driving Drum Machine, Synthesizer, Bass, Piano, Guitars, Instrumental, Clubby, Euphoric, Chicago, New York, 115 BPM. +Warm arpeggios on an analog synthesizer with a gradually rising filter cutoff and a reverb tail. +Blackbird song, summer, dusj in the forest. +Rock beat played in a treated studio, session drumming on an acoustic kit. 
\ No newline at end of file diff --git a/MindIE/MindIE-Torch/built-in/foundation/stable_audio_open_1.0/stable_audio_open/requirements.txt b/MindIE/MindIE-Torch/built-in/foundation/stable_audio_open_1.0/stable_audio_open/requirements.txt new file mode 100644 index 0000000000..6353806d3f --- /dev/null +++ b/MindIE/MindIE-Torch/built-in/foundation/stable_audio_open_1.0/stable_audio_open/requirements.txt @@ -0,0 +1,6 @@ +torch==2.1.0 +torchsde==0.2.6 +diffusers==0.30.0 +transformers==4.40.0 +soundfile==0.12.1 +torch_npu==2.1.0.post6 \ No newline at end of file diff --git a/MindIE/MindIE-Torch/built-in/foundation/stable_audio_open_1.0/stable_audio_open/stableaudio/__init__.py b/MindIE/MindIE-Torch/built-in/foundation/stable_audio_open_1.0/stable_audio_open/stableaudio/__init__.py new file mode 100644 index 0000000000..2e896ae145 --- /dev/null +++ b/MindIE/MindIE-Torch/built-in/foundation/stable_audio_open_1.0/stable_audio_open/stableaudio/__init__.py @@ -0,0 +1,4 @@ +from diffusers.models.autoencoders.autoencoder_oobleck import AutoencoderOobleck +from .pipeline import StableAudioPipeline, StableAudioProjectionModel +from .models import StableAudioDiTModel +from .schedulers.scheduling_cosine_dpmsolver_multistep import CosineDPMSolverMultistepScheduler \ No newline at end of file diff --git a/MindIE/MindIE-Torch/built-in/foundation/stable_audio_open_1.0/stable_audio_open/stableaudio/layers/__init__.py b/MindIE/MindIE-Torch/built-in/foundation/stable_audio_open_1.0/stable_audio_open/stableaudio/layers/__init__.py new file mode 100644 index 0000000000..f2f044203d --- /dev/null +++ b/MindIE/MindIE-Torch/built-in/foundation/stable_audio_open_1.0/stable_audio_open/stableaudio/layers/__init__.py @@ -0,0 +1,4 @@ +from .attention import ( + Attention, + StableAudioAttnProcessor2_0, +) \ No newline at end of file diff --git a/MindIE/MindIE-Torch/built-in/foundation/stable_audio_open_1.0/stable_audio_open/stableaudio/layers/attention.py b/MindIE/MindIE-Torch/built-in/foundation/stable_audio_open_1.0/stable_audio_open/stableaudio/layers/attention.py new file mode 100644 index 0000000000..1db4677c36 --- /dev/null +++ b/MindIE/MindIE-Torch/built-in/foundation/stable_audio_open_1.0/stable_audio_open/stableaudio/layers/attention.py @@ -0,0 +1,374 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import inspect +from typing import Any, Dict, List, Optional, Tuple + +import torch +import torch_npu +import torch.nn.functional as F +from torch import nn + +from diffusers.utils import logging +from diffusers.models.attention_processor import AttnProcessor, AttnProcessor2_0 +from mindspeed.ops.npu_rotary_position_embedding import npu_rotary_position_embedding +from .linear import QKVLinear + +logger = logging.get_logger(__name__) + + +class Attention(nn.Module): + r""" + A cross attention layer. + + Parameters: + query_dim (`int`): + The number of channels in the query. 
+ cross_attention_dim (`int`, *optional*): + The number of channels in the encoder_hidden_states. If not given, defaults to `query_dim`. + heads (`int`, *optional*, defaults to 8): + The number of heads to use for multi-head attention. + kv_heads (`int`, *optional*, defaults to `None`): + The number of key and value heads to use for multi-head attention. Defaults to `heads`. If + `kv_heads=heads`, the model will use Multi Head Attention (MHA), if `kv_heads=1` the model will use Multi + Query Attention (MQA) otherwise GQA is used. + dim_head (`int`, *optional*, defaults to 64): + The number of channels in each head. + dropout (`float`, *optional*, defaults to 0.0): + The dropout probability to use. + bias (`bool`, *optional*, defaults to False): + Set to `True` for the query, key, and value linear layers to contain a bias parameter. + upcast_attention (`bool`, *optional*, defaults to False): + Set to `True` to upcast the attention computation to `float32`. + upcast_softmax (`bool`, *optional*, defaults to False): + Set to `True` to upcast the softmax computation to `float32`. + cross_attention_norm (`str`, *optional*, defaults to `None`): + The type of normalization to use for the cross attention. Can be `None`, `layer_norm`, or `group_norm`. + cross_attention_norm_num_groups (`int`, *optional*, defaults to 32): + The number of groups to use for the group norm in the cross attention. + added_kv_proj_dim (`int`, *optional*, defaults to `None`): + The number of channels to use for the added key and value projections. If `None`, no projection is used. + norm_num_groups (`int`, *optional*, defaults to `None`): + The number of groups to use for the group norm in the attention. + spatial_norm_dim (`int`, *optional*, defaults to `None`): + The number of channels to use for the spatial normalization. + out_bias (`bool`, *optional*, defaults to `True`): + Set to `True` to use a bias in the output linear layer. + scale_qk (`bool`, *optional*, defaults to `True`): + Set to `True` to scale the query and key by `1 / sqrt(dim_head)`. + only_cross_attention (`bool`, *optional*, defaults to `False`): + Set to `True` to only use cross attention and not added_kv_proj_dim. Can only be set to `True` if + `added_kv_proj_dim` is not `None`. + eps (`float`, *optional*, defaults to 1e-5): + An additional value added to the denominator in group normalization that is used for numerical stability. + rescale_output_factor (`float`, *optional*, defaults to 1.0): + A factor to rescale the output by dividing it with this value. + residual_connection (`bool`, *optional*, defaults to `False`): + Set to `True` to add the residual connection to the output. + _from_deprecated_attn_block (`bool`, *optional*, defaults to `False`): + Set to `True` if the attention block is loaded from a deprecated state dict. + processor (`AttnProcessor`, *optional*, defaults to `None`): + The attention processor to use. If `None`, defaults to `AttnProcessor2_0` if `torch 2.x` is used and + `AttnProcessor` otherwise. 
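+        out_dim (`int`, *optional*, defaults to `None`):
+            The number of output channels of the attention layer. If `None`, defaults to `query_dim`.
+        context_pre_only (`bool`, *optional*, defaults to `None`):
+            Used together with `added_kv_proj_dim`; when not `None`, an added query projection is created, and an
+            added output projection is created only if it is `False`.
+        pre_only (`bool`, *optional*, defaults to `False`):
+            Set to `True` to skip creating the output projection (`to_out`).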
+ """ + + def __init__( + self, + query_dim: int, + cross_attention_dim: Optional[int] = None, + heads: int = 8, + kv_heads: Optional[int] = None, + dim_head: int = 64, + dropout: float = 0.0, + bias: bool = False, + upcast_attention: bool = False, + upcast_softmax: bool = False, + cross_attention_norm: Optional[str] = None, + cross_attention_norm_num_groups: int = 32, + qk_norm: Optional[str] = None, + added_kv_proj_dim: Optional[int] = None, + added_proj_bias: Optional[bool] = True, + norm_num_groups: Optional[int] = None, + spatial_norm_dim: Optional[int] = None, + out_bias: bool = True, + scale_qk: bool = True, + only_cross_attention: bool = False, + eps: float = 1e-5, + rescale_output_factor: float = 1.0, + residual_connection: bool = False, + _from_deprecated_attn_block: bool = False, + processor: Optional["AttnProcessor"] = None, + out_dim: int = None, + context_pre_only=None, + pre_only=False, + ): + super().__init__() + + self.inner_dim = out_dim if out_dim is not None else dim_head * heads + self.inner_kv_dim = self.inner_dim if kv_heads is None else dim_head * kv_heads + self.query_dim = query_dim + self.use_bias = bias + self.is_cross_attention = cross_attention_dim is not None + self.cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim + self.upcast_attention = upcast_attention + self.upcast_softmax = upcast_softmax + self.rescale_output_factor = rescale_output_factor + self.residual_connection = residual_connection + self.dropout = dropout + self.fused_projections = False + self.out_dim = out_dim if out_dim is not None else query_dim + self.context_pre_only = context_pre_only + self.pre_only = pre_only + + # we make use of this private variable to know whether this class is loaded + # with an deprecated state dict so that we can convert it on the fly + self._from_deprecated_attn_block = _from_deprecated_attn_block + + self.scale_qk = scale_qk + self.scale = dim_head**-0.5 if self.scale_qk else 1.0 + + self.heads = out_dim // dim_head if out_dim is not None else heads + self.dim_head = dim_head + # for slice_size > 0 the attention score computation + # is split across the batch axis to save memory + # You can set slice_size with `set_attention_slice` + self.sliceable_head_dim = heads + + self.added_kv_proj_dim = added_kv_proj_dim + self.only_cross_attention = only_cross_attention + + if self.added_kv_proj_dim is None and self.only_cross_attention: + raise ValueError( + "`only_cross_attention` can only be set to True if `added_kv_proj_dim` is not None. Make sure to set either `only_cross_attention=False` or define `added_kv_proj_dim`." + ) + + if norm_num_groups is not None: + self.group_norm = nn.GroupNorm(num_channels=query_dim, num_groups=norm_num_groups, eps=eps, affine=True) + else: + self.group_norm = None + + self.spatial_norm = None + + if qk_norm is None: + self.norm_q = None + self.norm_k = None + elif qk_norm == "layer_norm": + self.norm_q = nn.LayerNorm(dim_head, eps=eps) + self.norm_k = nn.LayerNorm(dim_head, eps=eps) + else: + raise ValueError(f"unknown qk_norm: {qk_norm}. 
Should be None or 'layer_norm'") + + if cross_attention_norm is None: + self.norm_cross = None + elif cross_attention_norm == "layer_norm": + self.norm_cross = nn.LayerNorm(self.cross_attention_dim) + elif cross_attention_norm == "group_norm": + if self.added_kv_proj_dim is not None: + # The given `encoder_hidden_states` are initially of shape + # (batch_size, seq_len, added_kv_proj_dim) before being projected + # to (batch_size, seq_len, cross_attention_dim). The norm is applied + # before the projection, so we need to use `added_kv_proj_dim` as + # the number of channels for the group norm. + norm_cross_num_channels = added_kv_proj_dim + else: + norm_cross_num_channels = self.cross_attention_dim + + self.norm_cross = nn.GroupNorm( + num_channels=norm_cross_num_channels, num_groups=cross_attention_norm_num_groups, eps=1e-5, affine=True + ) + else: + raise ValueError( + f"unknown cross_attention_norm: {cross_attention_norm}. Should be None, 'layer_norm' or 'group_norm'" + ) + + self.to_qkv = QKVLinear( + attention_dim=query_dim, + hidden_size=self.inner_dim, + qkv_bias=self.use_bias, + cross_attention_dim=cross_attention_dim, + cross_hidden_size=self.inner_kv_dim, + ) + + self.added_proj_bias = added_proj_bias + if self.added_kv_proj_dim is not None: + self.add_k_proj = nn.Linear(added_kv_proj_dim, self.inner_kv_dim, bias=added_proj_bias) + self.add_v_proj = nn.Linear(added_kv_proj_dim, self.inner_kv_dim, bias=added_proj_bias) + if self.context_pre_only is not None: + self.add_q_proj = nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias) + + if not self.pre_only: + self.to_out = nn.ModuleList([]) + self.to_out.append(nn.Linear(self.inner_dim, self.out_dim, bias=out_bias)) + self.to_out.append(nn.Dropout(dropout)) + + if self.context_pre_only is not None and not self.context_pre_only: + self.to_add_out = nn.Linear(self.inner_dim, self.out_dim, bias=out_bias) + + if qk_norm is not None and added_kv_proj_dim is not None: + if qk_norm == "fp32_layer_norm": + self.norm_added_q = FP32LayerNorm(dim_head, elementwise_affine=False, bias=False, eps=eps) + self.norm_added_k = FP32LayerNorm(dim_head, elementwise_affine=False, bias=False, eps=eps) + elif qk_norm == "rms_norm": + self.norm_added_q = RMSNorm(dim_head, eps=eps) + self.norm_added_k = RMSNorm(dim_head, eps=eps) + else: + self.norm_added_q = None + self.norm_added_k = None + + # set attention processor + # We use the AttnProcessor2_0 by default when torch 2.x is used which uses + # torch.nn.functional.scaled_dot_product_attention for native Flash/memory_efficient_attention + # but only if it has the default `scale` argument. TODO remove scale_qk check when we move to torch 2.1 + if processor is None: + processor = ( + AttnProcessor2_0() if hasattr(F, "scaled_dot_product_attention") and self.scale_qk else AttnProcessor() + ) + self.set_processor(processor) + + def set_processor(self, processor: "AttnProcessor") -> None: + r""" + Set the attention processor to use. + + Args: + processor (`AttnProcessor`): + The attention processor to use. 
+ """ + # if current processor is in `self._modules` and if passed `processor` is not, we need to + # pop `processor` from `self._modules` + if ( + hasattr(self, "processor") + and isinstance(self.processor, torch.nn.Module) + and not isinstance(processor, torch.nn.Module) + ): + logger.info(f"You are removing possibly trained weights of {self.processor} with {processor}") + self._modules.pop("processor") + + self.processor = processor + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + **cross_attention_kwargs, + ) -> torch.Tensor: + r""" + The forward method of the `Attention` class. + + Args: + hidden_states (`torch.Tensor`): + The hidden states of the query. + encoder_hidden_states (`torch.Tensor`, *optional*): + The hidden states of the encoder. + attention_mask (`torch.Tensor`, *optional*): + The attention mask to use. If `None`, no mask is applied. + **cross_attention_kwargs: + Additional keyword arguments to pass along to the cross attention. + + Returns: + `torch.Tensor`: The output of the attention layer. + """ + # The `Attention` class can call different attention processors / attention functions + # here we simply pass along all tensors to the selected processor class + # For standard processors that are defined here, `**cross_attention_kwargs` is empty + + attn_parameters = set(inspect.signature(self.processor.__call__).parameters.keys()) + quiet_attn_parameters = {"ip_adapter_masks"} + unused_kwargs = [ + k for k, _ in cross_attention_kwargs.items() if k not in attn_parameters and k not in quiet_attn_parameters + ] + if len(unused_kwargs) > 0: + logger.warning( + f"cross_attention_kwargs {unused_kwargs} are not expected by {self.processor.__class__.__name__} and will be ignored." + ) + cross_attention_kwargs = {k: w for k, w in cross_attention_kwargs.items() if k in attn_parameters} + + return self.processor( + self, + hidden_states, + encoder_hidden_states=encoder_hidden_states, + attention_mask=attention_mask, + **cross_attention_kwargs, + ) + + +class StableAudioAttnProcessor2_0: + r""" + Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). This is + used in the Stable Audio model. It applies rotary embedding on query and key vector, and allows MHA, GQA or MQA. + """ + + def __init__(self): + if not hasattr(F, "scaled_dot_product_attention"): + raise ImportError( + "StableAudioAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0." + ) + + def __call__( + self, + attn: Attention, + hidden_states: torch.Tensor, + encoder_hidden_states: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + rotary_emb: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + + batch_size, sequence_length, _ = ( + hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape + ) + + query, key, value = attn.to_qkv(hidden_states, encoder_hidden_states) + + head_dim = attn.dim_head + kv_heads = attn.inner_kv_dim // head_dim + + query = query.view(batch_size, -1, attn.heads, head_dim) + key = key.view(batch_size, -1, kv_heads, head_dim) + value = value.view(batch_size, -1, kv_heads, head_dim) + + if kv_heads != attn.heads: + # if GQA or MQA, repeat the key/value heads to reach the number of query heads. 
+ heads_per_kv_head = attn.heads // kv_heads + key = key.unsqueeze(3).repeat(1, 1, 1, heads_per_kv_head, 1) + key = key.view(batch_size, sequence_length, attn.heads, head_dim) + value = value.unsqueeze(3).repeat(1, 1, 1, heads_per_kv_head, 1) + value = value.view(batch_size, sequence_length, attn.heads, head_dim) + + # Apply RoPE if needed + if rotary_emb is not None: + cos, sin = rotary_emb + query_to_rotate, query_unrotated = torch.chunk(query, 2, 3) + query_rotated = npu_rotary_position_embedding(query_to_rotate, cos, sin) + query = torch.cat((query_rotated, query_unrotated), dim=-1) + + if not attn.is_cross_attention: + key_to_rotate, key_unrotated = torch.chunk(key, 2, 3) + key_rotated = npu_rotary_position_embedding(key_to_rotate, cos, sin) + key = torch.cat((key_rotated, key_unrotated), dim=-1) + + hidden_states = torch_npu.npu_fusion_attention( + query, key, value, + atten_mask=attention_mask, + input_layout='BSND', + scale=attn.scale, + head_num=attn.heads, + )[0] + + hidden_states = hidden_states.reshape(batch_size, -1, attn.heads * head_dim) + + # linear proj + hidden_states = attn.to_out[0](hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + + return hidden_states \ No newline at end of file diff --git a/MindIE/MindIE-Torch/built-in/foundation/stable_audio_open_1.0/stable_audio_open/stableaudio/layers/linear.py b/MindIE/MindIE-Torch/built-in/foundation/stable_audio_open_1.0/stable_audio_open/stableaudio/layers/linear.py new file mode 100644 index 0000000000..3e8ee84fc8 --- /dev/null +++ b/MindIE/MindIE-Torch/built-in/foundation/stable_audio_open_1.0/stable_audio_open/stableaudio/layers/linear.py @@ -0,0 +1,99 @@ +#!/usr/bin/env python +# coding=utf-8 +# Copyright 2024 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
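+
+# QKVLinear fuses the separate q/k/v projections of an attention layer: self-attention uses a single
+# [attention_dim, 3 * hidden_size] weight (one matmul/addmm instead of three), while cross-attention keeps
+# a q weight plus a fused k/v weight over the encoder hidden states. Weights are stored as
+# (in_features, out_features), i.e. transposed relative to nn.Linear, so they can be passed to torch.addmm directly.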
+ + +import torch +import torch.nn as nn +import torch_npu + + +class QKVLinear(nn.Module): + def __init__(self, attention_dim, hidden_size, qkv_bias=True, cross_attention_dim=None, cross_hidden_size=None, + device=None, dtype=None): + super(QKVLinear, self).__init__() + self.attention_dim = attention_dim + self.hidden_size = hidden_size + + self.cross_attention_dim = cross_attention_dim + self.cross_hidden_size = self.hidden_size if cross_hidden_size is None else cross_hidden_size + self.qkv_bias = qkv_bias + + factory_kwargs = {"device": device, "dtype": dtype} + + if cross_attention_dim is None: + self.weight = nn.Parameter(torch.empty([self.attention_dim, 3 * self.hidden_size], **factory_kwargs)) + if self.qkv_bias: + self.bias = nn.Parameter(torch.empty([3 * self.hidden_size], **factory_kwargs)) + else: + self.q_weight = nn.Parameter(torch.empty([self.attention_dim, self.hidden_size], **factory_kwargs)) + self.kv_weight = nn.Parameter( + torch.empty([self.cross_attention_dim, 2 * self.cross_hidden_size], **factory_kwargs)) + + if self.qkv_bias: + self.q_bias = nn.Parameter(torch.empty([self.hidden_size], **factory_kwargs)) + self.kv_bias = nn.Parameter(torch.empty([2 * self.cross_hidden_size], **factory_kwargs)) + + def forward(self, hidden_states, encoder_hidden_states=None): + + if self.cross_attention_dim is None: + if not self.qkv_bias: + qkv = torch.matmul(hidden_states, self.weight) + else: + qkv = torch.addmm( + self.bias, + hidden_states.view(hidden_states.size(0) * hidden_states.size(1), hidden_states.size(2)), + self.weight, + beta=1, + alpha=1 + ) + + batch, seqlen, _ = hidden_states.shape + qkv_shape = (batch, seqlen, 3, -1) + qkv = qkv.view(qkv_shape) + q, k, v = qkv.unbind(2) + + else: + if not self.qkv_bias: + q = torch.matmul(hidden_states, self.q_weight) + kv = torch.matmul(encoder_hidden_states, self.kv_weight) + else: + q = torch.addmm( + self.q_bias, + hidden_states.view(hidden_states.size(0) * hidden_states.size(1), hidden_states.size(2)), + self.q_weight, + beta=1, + alpha=1 + ) + kv = torch.addmm( + self.kv_bias, + encoder_hidden_states.view( + encoder_hidden_states.size(0) * encoder_hidden_states.size(1), + encoder_hidden_states.size(2)), + self.kv_weight, + beta=1, + alpha=1 + ) + + batch, q_seqlen, _ = hidden_states.shape + q = q.view(batch, q_seqlen, -1) + + batch, kv_seqlen, _ = encoder_hidden_states.shape + kv_shape = (batch, kv_seqlen, 2, -1) + + kv = kv.view(kv_shape) + k, v = kv.unbind(2) + + return q, k, v diff --git a/MindIE/MindIE-Torch/built-in/foundation/stable_audio_open_1.0/stable_audio_open/stableaudio/models/__init__.py b/MindIE/MindIE-Torch/built-in/foundation/stable_audio_open_1.0/stable_audio_open/stableaudio/models/__init__.py new file mode 100644 index 0000000000..2c132c4286 --- /dev/null +++ b/MindIE/MindIE-Torch/built-in/foundation/stable_audio_open_1.0/stable_audio_open/stableaudio/models/__init__.py @@ -0,0 +1 @@ +from .stable_audio_transformer import StableAudioDiTModel \ No newline at end of file diff --git a/MindIE/MindIE-Torch/built-in/foundation/stable_audio_open_1.0/stable_audio_open/stableaudio/models/stable_audio_transformer.py b/MindIE/MindIE-Torch/built-in/foundation/stable_audio_open_1.0/stable_audio_open/stableaudio/models/stable_audio_transformer.py new file mode 100644 index 0000000000..cd8a822c1b --- /dev/null +++ b/MindIE/MindIE-Torch/built-in/foundation/stable_audio_open_1.0/stable_audio_open/stableaudio/models/stable_audio_transformer.py @@ -0,0 +1,428 @@ +# Copyright 2024 Stability AI and The HuggingFace Team. 
All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from typing import Optional, Union +from collections import OrderedDict + +import numpy as np +import torch +import torch.nn as nn +import torch.utils.checkpoint + +from diffusers.configuration_utils import ConfigMixin, register_to_config +from diffusers.models.attention import FeedForward +from diffusers.models.transformers.transformer_2d import Transformer2DModelOutput +from diffusers.models.modeling_utils import ModelMixin +from diffusers.utils import logging + +from ..layers.attention import ( + Attention, + StableAudioAttnProcessor2_0, +) + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +class StableAudioGaussianFourierProjection(nn.Module): + """Gaussian Fourier embeddings for noise levels.""" + + # Copied from diffusers.models.embeddings.GaussianFourierProjection.__init__ + def __init__( + self, embedding_size: int = 256, scale: float = 1.0, set_W_to_weight=True, log=True, flip_sin_to_cos=False + ): + super().__init__() + self.weight = nn.Parameter(torch.randn(embedding_size) * scale, requires_grad=False) + self.log = log + self.flip_sin_to_cos = flip_sin_to_cos + + if set_W_to_weight: + # to delete later + del self.weight + self.W = nn.Parameter(torch.randn(embedding_size) * scale, requires_grad=False) + self.weight = self.W + del self.W + + def forward(self, x): + if self.log: + x = torch.log(x) + + x_proj = 2 * np.pi * x[:, None] @ self.weight[None, :] + + if self.flip_sin_to_cos: + out = torch.cat([torch.cos(x_proj), torch.sin(x_proj)], dim=-1) + else: + out = torch.cat([torch.sin(x_proj), torch.cos(x_proj)], dim=-1) + return out + + +class StableAudioDiTBlock(nn.Module): + r""" + Transformer block used in Stable Audio model (https://github.com/Stability-AI/stable-audio-tools). Allow skip + connection and QKNorm + + Parameters: + dim (`int`): The number of channels in the input and output. + num_attention_heads (`int`): The number of heads to use for the query states. + num_key_value_attention_heads (`int`): The number of heads to use for the key and value states. + attention_head_dim (`int`): The number of channels in each head. + dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. + cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention. + upcast_attention (`bool`, *optional*): + Whether to upcast the attention computation to float32. This is useful for mixed precision training. + """ + + def __init__( + self, + dim: int, + num_attention_heads: int, + num_key_value_attention_heads: int, + attention_head_dim: int, + dropout=0.0, + cross_attention_dim: Optional[int] = None, + upcast_attention: bool = False, + norm_eps: float = 1e-5, + ff_inner_dim: Optional[int] = None, + ): + super().__init__() + # Define 3 blocks. Each block has its own normalization layer. + # 1. 
Self-Attn + self.norm1 = nn.LayerNorm(dim, elementwise_affine=True, eps=norm_eps) + self.attn1 = Attention( + query_dim=dim, + heads=num_attention_heads, + dim_head=attention_head_dim, + dropout=dropout, + bias=False, + upcast_attention=upcast_attention, + out_bias=False, + processor=StableAudioAttnProcessor2_0(), + ) + + # 2. Cross-Attn + self.norm2 = nn.LayerNorm(dim, norm_eps, True) + + self.attn2 = Attention( + query_dim=dim, + cross_attention_dim=cross_attention_dim, + heads=num_attention_heads, + dim_head=attention_head_dim, + kv_heads=num_key_value_attention_heads, + dropout=dropout, + bias=False, + upcast_attention=upcast_attention, + out_bias=False, + processor=StableAudioAttnProcessor2_0(), + ) # is self-attn if encoder_hidden_states is none + + # 3. Feed-forward + self.norm3 = nn.LayerNorm(dim, norm_eps, True) + self.ff = FeedForward( + dim, + dropout=dropout, + activation_fn="swiglu", + final_dropout=False, + inner_dim=ff_inner_dim, + bias=True, + ) + + # let chunk size default to None + self._chunk_size = None + self._chunk_dim = 0 + + def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int = 0): + # Sets chunk feed-forward + self._chunk_size = chunk_size + self._chunk_dim = dim + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + rotary_embedding: Optional[torch.FloatTensor] = None, + ) -> torch.Tensor: + # Notice that normalization is always applied before the real computation in the following blocks. + # 0. Self-Attention + norm_hidden_states = self.norm1(hidden_states) + + attn_output = self.attn1( + norm_hidden_states, + attention_mask=attention_mask, + rotary_emb=rotary_embedding, + ) + + hidden_states = attn_output + hidden_states + + # 2. Cross-Attention + norm_hidden_states = self.norm2(hidden_states) + + attn_output = self.attn2( + norm_hidden_states, + encoder_hidden_states=encoder_hidden_states, + attention_mask=encoder_attention_mask, + ) + hidden_states = attn_output + hidden_states + + # 3. Feed-forward + norm_hidden_states = self.norm3(hidden_states) + ff_output = self.ff(norm_hidden_states) + + hidden_states = ff_output + hidden_states + + return hidden_states + + +class StableAudioDiTModel(ModelMixin, ConfigMixin): + """ + The Diffusion Transformer model introduced in Stable Audio. + + Reference: https://github.com/Stability-AI/stable-audio-tools + + Parameters: + sample_size ( `int`, *optional*, defaults to 1024): The size of the input sample. + in_channels (`int`, *optional*, defaults to 64): The number of channels in the input. + num_layers (`int`, *optional*, defaults to 24): The number of layers of Transformer blocks to use. + attention_head_dim (`int`, *optional*, defaults to 64): The number of channels in each head. + num_attention_heads (`int`, *optional*, defaults to 24): The number of heads to use for the query states. + num_key_value_attention_heads (`int`, *optional*, defaults to 12): + The number of heads to use for the key and value states. + out_channels (`int`, defaults to 64): Number of output channels. + cross_attention_dim ( `int`, *optional*, defaults to 768): Dimension of the cross-attention projection. + time_proj_dim ( `int`, *optional*, defaults to 256): Dimension of the timestep inner projection. + global_states_input_dim ( `int`, *optional*, defaults to 1536): + Input dimension of the global hidden states projection. 
+ cross_attention_input_dim ( `int`, *optional*, defaults to 768): + Input dimension of the cross-attention projection + """ + + _supports_gradient_checkpointing = True + + @register_to_config + def __init__( + self, + sample_size: int = 1024, + in_channels: int = 64, + num_layers: int = 24, + attention_head_dim: int = 64, + num_attention_heads: int = 24, + num_key_value_attention_heads: int = 12, + out_channels: int = 64, + cross_attention_dim: int = 768, + time_proj_dim: int = 256, + global_states_input_dim: int = 1536, + cross_attention_input_dim: int = 768, + ): + super().__init__() + + self.cache_block_start = 11 + self.cache_step_interval = 2 + self.cache_num_blocks = 9 + self.cache_step_start = 5 + + self.num_layers = num_layers + + self.sample_size = sample_size + self.out_channels = out_channels + self.inner_dim = num_attention_heads * attention_head_dim + self.init_dtype = self.dtype + self.time_proj = StableAudioGaussianFourierProjection( + embedding_size=time_proj_dim // 2, + flip_sin_to_cos=True, + log=False, + set_W_to_weight=False, + ) + + self.timestep_proj = nn.Sequential( + nn.Linear(time_proj_dim, self.inner_dim, bias=True), + nn.SiLU(), + nn.Linear(self.inner_dim, self.inner_dim, bias=True), + ) + + self.global_proj = nn.Sequential( + nn.Linear(global_states_input_dim, self.inner_dim, bias=False), + nn.SiLU(), + nn.Linear(self.inner_dim, self.inner_dim, bias=False), + ) + + self.cross_attention_proj = nn.Sequential( + nn.Linear(cross_attention_input_dim, cross_attention_dim, bias=False), + nn.SiLU(), + nn.Linear(cross_attention_dim, cross_attention_dim, bias=False), + ) + + self.preprocess_conv = nn.Conv1d(in_channels, in_channels, 1, bias=False) + self.proj_in = nn.Linear(in_channels, self.inner_dim, bias=False) + + self.transformer_blocks = nn.ModuleList( + [ + StableAudioDiTBlock( + dim=self.inner_dim, + num_attention_heads=num_attention_heads, + num_key_value_attention_heads=num_key_value_attention_heads, + attention_head_dim=attention_head_dim, + cross_attention_dim=cross_attention_dim, + ) + for i in range(num_layers) + ] + ) + + self.proj_out = nn.Linear(self.inner_dim, self.out_channels, bias=False) + self.postprocess_conv = nn.Conv1d(self.out_channels, self.out_channels, 1, bias=False) + + self.gradient_checkpointing = False + + def forward( + self, + step_id, + hidden_states: torch.FloatTensor, + timestep: torch.LongTensor = None, + encoder_hidden_states: torch.FloatTensor = None, + global_hidden_states: torch.FloatTensor = None, + rotary_embedding: torch.FloatTensor = None, + return_dict: bool = True, + attention_mask: Optional[torch.LongTensor] = None, + encoder_attention_mask: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = False, + ) -> Union[torch.FloatTensor, Transformer2DModelOutput]: + """ + The [`StableAudioDiTModel`] forward method. + + Args: + hidden_states (`torch.FloatTensor` of shape `(batch size, in_channels, sequence_len)`): + Input `hidden_states`. + timestep ( `torch.LongTensor`): + Used to indicate denoising step. + encoder_hidden_states (`torch.FloatTensor` of shape `(batch size, encoder_sequence_len, cross_attention_input_dim)`): + Conditional embeddings (embeddings computed from the input conditions such as prompts) to use. + global_hidden_states (`torch.FloatTensor` of shape `(batch size, global_sequence_len, global_states_input_dim)`): + Global embeddings that will be prepended to the hidden states. 
+ rotary_embedding (`torch.Tensor`): + The rotary embeddings to apply on query and key tensors during attention calculation. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain + tuple. + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_len)`, *optional*): + Mask to avoid performing attention on padding token indices, formed by concatenating the attention + masks + for the two text encoders together. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + encoder_attention_mask (`torch.Tensor` of shape `(batch_size, sequence_len)`, *optional*): + Mask to avoid performing attention on padding token cross-attention indices, formed by concatenating + the attention masks + for the two text encoders together. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + Returns: + If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a + `tuple` where the first element is the sample tensor. + """ + cross_attention_hidden_states = self.cross_attention_proj(encoder_hidden_states) + global_hidden_states = self.global_proj(global_hidden_states) + time_hidden_states = self.timestep_proj(self.time_proj(timestep.to(self.init_dtype))) + + global_hidden_states = global_hidden_states + time_hidden_states.unsqueeze(1) + + hidden_states = self.preprocess_conv(hidden_states) + hidden_states + # (batch_size, dim, sequence_length) -> (batch_size, sequence_length, dim) + hidden_states = hidden_states.transpose(1, 2) + + hidden_states = self.proj_in(hidden_states) + + # prepend global states to hidden states + hidden_states = torch.cat([global_hidden_states, hidden_states], dim=-2) + if attention_mask is not None: + prepend_mask = torch.ones((hidden_states.shape[0], 1), device=hidden_states.device, dtype=torch.bool) + attention_mask = torch.cat([prepend_mask, attention_mask], dim=-1) + + if not use_cache or (use_cache and step_id < self.cache_step_start): + hidden_states = self._transformer_blocks_forward( + hidden_states=hidden_states, + encoder_hidden_states=cross_attention_hidden_states, + rotary_embedding=rotary_embedding, + start_id=0, + end_id=self.num_layers, + ) + else: + hidden_states = self._transformer_blocks_forward( + hidden_states=hidden_states, + encoder_hidden_states=cross_attention_hidden_states, + rotary_embedding=rotary_embedding, + start_id=0, + end_id=self.cache_block_start, + ) + + cache_end = np.minimum(self.cache_block_start + self.cache_num_blocks, self.num_layers) + hidden_states_pre_cache = hidden_states.clone() + if (step_id - self.cache_step_start) % self.cache_step_interval == 0: + hidden_states = self._transformer_blocks_forward( + hidden_states=hidden_states, + encoder_hidden_states=cross_attention_hidden_states, + rotary_embedding=rotary_embedding, + start_id=self.cache_block_start, + end_id=cache_end, + ) + self.delta_cache = hidden_states - hidden_states_pre_cache + else: + hidden_states = hidden_states_pre_cache + self.delta_cache + + if cache_end < self.num_layers: + hidden_states = self._transformer_blocks_forward( + hidden_states=hidden_states, + encoder_hidden_states=cross_attention_hidden_states, + rotary_embedding=rotary_embedding, + start_id=cache_end, + end_id=self.num_layers, + ) + + hidden_states = self.proj_out(hidden_states) + + # (batch_size, sequence_length, dim) -> (batch_size, 
dim, sequence_length) + # remove prepend length that has been added by global hidden states + hidden_states = hidden_states.transpose(1, 2)[:, :, 1:] + hidden_states = self.postprocess_conv(hidden_states) + hidden_states + + if not return_dict: + return (hidden_states,) + + return Transformer2DModelOutput(sample=hidden_states) + + def _transformer_blocks_forward(self, hidden_states, encoder_hidden_states, rotary_embedding, start_id, end_id): + for block in self.transformer_blocks[start_id: end_id]: + hidden_states = block( + hidden_states=hidden_states, + encoder_hidden_states=encoder_hidden_states, + rotary_embedding=rotary_embedding, + ) + return hidden_states + + def _convert_deprecated_attention_blocks(self, state_dict: OrderedDict)->None: + for i in range(self.num_layers): + self_q_weight = state_dict.pop(f"transformer_blocks.{i}.attn1.to_q.weight", None) + self_k_weight = state_dict.pop(f"transformer_blocks.{i}.attn1.to_k.weight", None) + self_v_weight = state_dict.pop(f"transformer_blocks.{i}.attn1.to_v.weight", None) + self_qkv_weight = torch.cat([self_q_weight, self_k_weight, self_v_weight], dim=0).transpose(0, 1).contiguous() + state_dict[f"transformer_blocks.{i}.attn1.to_qkv.weight"] = self_qkv_weight + + cross_q_weight = state_dict.pop(f"transformer_blocks.{i}.attn2.to_q.weight", None) + cross_k_weight = state_dict.pop(f"transformer_blocks.{i}.attn2.to_k.weight", None) + cross_v_weight = state_dict.pop(f"transformer_blocks.{i}.attn2.to_v.weight", None) + cross_q_weight = cross_q_weight.transpose(0, 1).contiguous() + cross_kv_weight = torch.cat([cross_k_weight, cross_v_weight], dim=0).transpose(0, 1).contiguous() + state_dict[f"transformer_blocks.{i}.attn2.to_qkv.q_weight"] = cross_q_weight + state_dict[f"transformer_blocks.{i}.attn2.to_qkv.kv_weight"] = cross_kv_weight \ No newline at end of file diff --git a/MindIE/MindIE-Torch/built-in/foundation/stable_audio_open_1.0/stable_audio_open/stableaudio/pipeline/__init__.py b/MindIE/MindIE-Torch/built-in/foundation/stable_audio_open_1.0/stable_audio_open/stableaudio/pipeline/__init__.py new file mode 100644 index 0000000000..a283da8ecd --- /dev/null +++ b/MindIE/MindIE-Torch/built-in/foundation/stable_audio_open_1.0/stable_audio_open/stableaudio/pipeline/__init__.py @@ -0,0 +1 @@ +from .pipeline_stable_audio import StableAudioPipeline, StableAudioProjectionModel \ No newline at end of file diff --git a/MindIE/MindIE-Torch/built-in/foundation/stable_audio_open_1.0/stable_audio_open/stableaudio/pipeline/pipeline_stable_audio.py b/MindIE/MindIE-Torch/built-in/foundation/stable_audio_open_1.0/stable_audio_open/stableaudio/pipeline/pipeline_stable_audio.py new file mode 100644 index 0000000000..3abd5ecc9d --- /dev/null +++ b/MindIE/MindIE-Torch/built-in/foundation/stable_audio_open_1.0/stable_audio_open/stableaudio/pipeline/pipeline_stable_audio.py @@ -0,0 +1,754 @@ +# Copyright 2024 Stability AI and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
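+
+# This file adapts diffusers' StableAudioPipeline to the local StableAudioDiTModel, which accepts a
+# `step_id` and an optional `use_cache` flag to enable the DiT-cache strategy, and to the local
+# CosineDPMSolverMultistepScheduler.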
+ +import inspect +from typing import Callable, List, Optional, Union + +import torch +from transformers import ( + T5EncoderModel, + T5Tokenizer, + T5TokenizerFast, +) + +from diffusers.utils.torch_utils import randn_tensor +from diffusers.pipelines.pipeline_utils import AudioPipelineOutput, DiffusionPipeline +from diffusers.pipelines.stable_audio.modeling_stable_audio import StableAudioProjectionModel +from diffusers.models.embeddings import get_1d_rotary_pos_embed +from diffusers.models.autoencoders.autoencoder_oobleck import AutoencoderOobleck +from diffusers.utils import ( + logging, + replace_example_docstring, +) + +from ..models import StableAudioDiTModel +from ..schedulers.scheduling_cosine_dpmsolver_multistep import CosineDPMSolverMultistepScheduler + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import scipy + >>> import torch + >>> import soundfile as sf + >>> from diffusers import StableAudioPipeline + + >>> repo_id = "stabilityai/stable-audio-open-1.0" + >>> pipe = StableAudioPipeline.from_pretrained(repo_id, torch_dtype=torch.float16) + >>> pipe = pipe.to("cuda") + + >>> # define the prompts + >>> prompt = "The sound of a hammer hitting a wooden surface." + >>> negative_prompt = "Low quality." + + >>> # set the seed for generator + >>> generator = torch.Generator("cuda").manual_seed(0) + + >>> # run the generation + >>> audio = pipe( + ... prompt, + ... negative_prompt=negative_prompt, + ... num_inference_steps=200, + ... audio_end_in_s=10.0, + ... num_waveforms_per_prompt=3, + ... generator=generator, + ... ).audios + + >>> output = audio[0].T.float().cpu().numpy() + >>> sf.write("hammer.wav", output, pipe.vae.sampling_rate) + ``` +""" + + +class StableAudioPipeline(DiffusionPipeline): + r""" + Pipeline for text-to-audio generation using StableAudio. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods + implemented for all pipelines (downloading, saving, running on a particular device, etc.). + + Args: + vae ([`AutoencoderOobleck`]): + Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations. + text_encoder ([`~transformers.T5EncoderModel`]): + Frozen text-encoder. StableAudio uses the encoder of + [T5](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5EncoderModel), specifically the + [google-t5/t5-base](https://huggingface.co/google-t5/t5-base) variant. + projection_model ([`StableAudioProjectionModel`]): + A trained model used to linearly project the hidden-states from the text encoder model and the start and + end seconds. The projected hidden-states from the encoder and the conditional seconds are concatenated to + give the input to the transformer model. + tokenizer ([`~transformers.T5Tokenizer`]): + Tokenizer to tokenize text for the frozen text-encoder. + transformer ([`StableAudioDiTModel`]): + A `StableAudioDiTModel` to denoise the encoded audio latents. + scheduler ([`CosineDPMSolverMultistepScheduler`]): + A scheduler to be used in combination with `transformer` to denoise the encoded audio latents. 
+ """ + + model_cpu_offload_seq = "text_encoder->projection_model->transformer->vae" + + def __init__( + self, + vae: AutoencoderOobleck, + text_encoder: T5EncoderModel, + projection_model: StableAudioProjectionModel, + tokenizer: Union[T5Tokenizer, T5TokenizerFast], + transformer: StableAudioDiTModel, + scheduler: CosineDPMSolverMultistepScheduler, + ): + super().__init__() + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + projection_model=projection_model, + tokenizer=tokenizer, + transformer=transformer, + scheduler=scheduler, + ) + self.rotary_embed_dim = self.transformer.config.attention_head_dim // 2 + + # Copied from diffusers.pipelines.pipeline_utils.StableDiffusionMixin.enable_vae_slicing + def enable_vae_slicing(self): + r""" + Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to + compute decoding in several steps. This is useful to save some memory and allow larger batch sizes. + """ + self.vae.enable_slicing() + + # Copied from diffusers.pipelines.pipeline_utils.StableDiffusionMixin.disable_vae_slicing + def disable_vae_slicing(self): + r""" + Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to + computing decoding in one step. + """ + self.vae.disable_slicing() + + def encode_prompt( + self, + prompt, + device, + do_classifier_free_guidance, + negative_prompt=None, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.LongTensor] = None, + negative_attention_mask: Optional[torch.LongTensor] = None, + ): + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + # 1. Tokenize text + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + attention_mask = text_inputs.attention_mask + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = self.tokenizer.batch_decode( + untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] + ) + logger.warning( + f"The following part of your input was truncated because {self.text_encoder.config.model_type} can " + f"only handle sequences up to {self.tokenizer.model_max_length} tokens: {removed_text}" + ) + + text_input_ids = text_input_ids.to(device) + attention_mask = attention_mask.to(device) + + # 2. Text encoder forward + self.text_encoder.eval() + prompt_embeds = self.text_encoder( + text_input_ids, + attention_mask=attention_mask, + ) + prompt_embeds = prompt_embeds[0] + + if do_classifier_free_guidance and negative_prompt is not None: + uncond_tokens: List[str] + if type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. 
Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = negative_prompt + + # 1. Tokenize text + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + + uncond_input_ids = uncond_input.input_ids.to(device) + negative_attention_mask = uncond_input.attention_mask.to(device) + + # 2. Text encoder forward + self.text_encoder.eval() + negative_prompt_embeds = self.text_encoder( + uncond_input_ids, + attention_mask=negative_attention_mask, + ) + negative_prompt_embeds = negative_prompt_embeds[0] + + if negative_attention_mask is not None: + # set the masked tokens to the null embed + negative_prompt_embeds = torch.where( + negative_attention_mask.to(torch.bool).unsqueeze(2), negative_prompt_embeds, 0.0 + ) + + # 3. Project prompt_embeds and negative_prompt_embeds + if do_classifier_free_guidance and negative_prompt_embeds is not None: + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the negative and text embeddings into a single batch + # to avoid doing two forward passes + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) + if attention_mask is not None and negative_attention_mask is None: + negative_attention_mask = torch.ones_like(attention_mask) + elif attention_mask is None and negative_attention_mask is not None: + attention_mask = torch.ones_like(negative_attention_mask) + + if attention_mask is not None: + attention_mask = torch.cat([negative_attention_mask, attention_mask]) + + prompt_embeds = self.projection_model( + text_hidden_states=prompt_embeds, + ).text_hidden_states + if attention_mask is not None: + prompt_embeds = prompt_embeds * attention_mask.unsqueeze(-1).to(prompt_embeds.dtype) + prompt_embeds = prompt_embeds * attention_mask.unsqueeze(-1).to(prompt_embeds.dtype) + + return prompt_embeds + + def encode_duration( + self, + audio_start_in_s, + audio_end_in_s, + device, + do_classifier_free_guidance, + batch_size, + ): + audio_start_in_s = audio_start_in_s if isinstance(audio_start_in_s, list) else [audio_start_in_s] + audio_end_in_s = audio_end_in_s if isinstance(audio_end_in_s, list) else [audio_end_in_s] + + if len(audio_start_in_s) == 1: + audio_start_in_s = audio_start_in_s * batch_size + if len(audio_end_in_s) == 1: + audio_end_in_s = audio_end_in_s * batch_size + + # Cast the inputs to floats + audio_start_in_s = [float(x) for x in audio_start_in_s] + audio_start_in_s = torch.tensor(audio_start_in_s).to(device) + + audio_end_in_s = [float(x) for x in audio_end_in_s] + audio_end_in_s = torch.tensor(audio_end_in_s).to(device) + + projection_output = self.projection_model( + start_seconds=audio_start_in_s, + end_seconds=audio_end_in_s, + ) + seconds_start_hidden_states = projection_output.seconds_start_hidden_states + seconds_end_hidden_states = projection_output.seconds_end_hidden_states + + # For classifier free guidance, we need to do two forward passes. 
+ # Here we repeat the audio hidden states to avoid doing two forward passes + if do_classifier_free_guidance: + seconds_start_hidden_states = torch.cat([seconds_start_hidden_states, seconds_start_hidden_states], dim=0) + seconds_end_hidden_states = torch.cat([seconds_end_hidden_states, seconds_end_hidden_states], dim=0) + + return seconds_start_hidden_states, seconds_end_hidden_states + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + def check_inputs( + self, + prompt, + audio_start_in_s, + audio_end_in_s, + callback_steps, + negative_prompt=None, + prompt_embeds=None, + negative_prompt_embeds=None, + attention_mask=None, + negative_attention_mask=None, + initial_audio_waveforms=None, + initial_audio_sampling_rate=None, + ): + if audio_end_in_s < audio_start_in_s: + raise ValueError( + f"`audio_end_in_s={audio_end_in_s}' must be higher than 'audio_start_in_s={audio_start_in_s}` but " + ) + + if ( + audio_start_in_s < self.projection_model.config.min_value + or audio_start_in_s > self.projection_model.config.max_value + ): + raise ValueError( + f"`audio_start_in_s` must be greater than or equal to {self.projection_model.config.min_value}, and lower than or equal to {self.projection_model.config.max_value} but " + f"is {audio_start_in_s}." + ) + + if ( + audio_end_in_s < self.projection_model.config.min_value + or audio_end_in_s > self.projection_model.config.max_value + ): + raise ValueError( + f"`audio_end_in_s` must be greater than or equal to {self.projection_model.config.min_value}, and lower than or equal to {self.projection_model.config.max_value} but " + f"is {audio_end_in_s}." + ) + + if (callback_steps is None) or ( + callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) + ): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and (prompt_embeds is None): + raise ValueError( + "Provide either `prompt`, or `prompt_embeds`. Cannot leave" + "`prompt` undefined without specifying `prompt_embeds`." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. 
Please make sure to only forward one of the two."
+            )
+
+        if prompt_embeds is not None and negative_prompt_embeds is not None:
+            if prompt_embeds.shape != negative_prompt_embeds.shape:
+                raise ValueError(
+                    "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
+                    f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
+                    f" {negative_prompt_embeds.shape}."
+                )
+            if attention_mask is not None and attention_mask.shape != prompt_embeds.shape[:2]:
+                raise ValueError(
+                    "`attention_mask` should have the same batch size and sequence length as `prompt_embeds`, but got:"
+                    f" `attention_mask`: {attention_mask.shape} != `prompt_embeds`: {prompt_embeds.shape}"
+                )
+
+        if initial_audio_sampling_rate is None and initial_audio_waveforms is not None:
+            raise ValueError(
+                "`initial_audio_waveforms` is provided but the sampling rate is not. Make sure to pass `initial_audio_sampling_rate`."
+            )
+
+        if initial_audio_sampling_rate is not None and initial_audio_sampling_rate != self.vae.sampling_rate:
+            raise ValueError(
+                f"`initial_audio_sampling_rate` must be {self.vae.sampling_rate} but is `{initial_audio_sampling_rate}`."
+                " Make sure to resample the `initial_audio_waveforms` and to correct the sampling rate. "
+            )
+
+    def prepare_latents(
+        self,
+        batch_size,
+        num_channels_vae,
+        sample_size,
+        dtype,
+        device,
+        generator,
+        latents=None,
+        initial_audio_waveforms=None,
+        num_waveforms_per_prompt=None,
+        audio_channels=None,
+    ):
+        shape = (batch_size, num_channels_vae, sample_size)
+        if isinstance(generator, list) and len(generator) != batch_size:
+            raise ValueError(
+                f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+                f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+ ) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + + # encode the initial audio for use by the model + if initial_audio_waveforms is not None: + # check dimension + if initial_audio_waveforms.ndim == 2: + initial_audio_waveforms = initial_audio_waveforms.unsqueeze(1) + elif initial_audio_waveforms.ndim != 3: + raise ValueError( + f"`initial_audio_waveforms` must be of shape `(batch_size, num_channels, audio_length)` or `(batch_size, audio_length)` but has `{initial_audio_waveforms.ndim}` dimensions" + ) + + audio_vae_length = self.transformer.config.sample_size * self.vae.hop_length + audio_shape = (batch_size // num_waveforms_per_prompt, audio_channels, audio_vae_length) + + # check num_channels + if initial_audio_waveforms.shape[1] == 1 and audio_channels == 2: + initial_audio_waveforms = initial_audio_waveforms.repeat(1, 2, 1) + elif initial_audio_waveforms.shape[1] == 2 and audio_channels == 1: + initial_audio_waveforms = initial_audio_waveforms.mean(1, keepdim=True) + + if initial_audio_waveforms.shape[:2] != audio_shape[:2]: + raise ValueError( + f"`initial_audio_waveforms` must be of shape `(batch_size, num_channels, audio_length)` or `(batch_size, audio_length)` but is of shape `{initial_audio_waveforms.shape}`" + ) + + # crop or pad + audio_length = initial_audio_waveforms.shape[-1] + if audio_length < audio_vae_length: + logger.warning( + f"The provided input waveform is shorter ({audio_length}) than the required audio length ({audio_vae_length}) of the model and will thus be padded." + ) + elif audio_length > audio_vae_length: + logger.warning( + f"The provided input waveform is longer ({audio_length}) than the required audio length ({audio_vae_length}) of the model and will thus be cropped." + ) + + audio = initial_audio_waveforms.new_zeros(audio_shape) + audio[:, :, : min(audio_length, audio_vae_length)] = initial_audio_waveforms[:, :, :audio_vae_length] + + encoded_audio = self.vae.encode(audio).latent_dist.sample(generator) + encoded_audio = encoded_audio.repeat((num_waveforms_per_prompt, 1, 1)) + latents = encoded_audio + latents + return latents + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: Union[str, List[str]] = None, + audio_end_in_s: Optional[float] = None, + audio_start_in_s: Optional[float] = 0.0, + num_inference_steps: int = 100, + guidance_scale: float = 7.0, + negative_prompt: Optional[Union[str, List[str]]] = None, + num_waveforms_per_prompt: Optional[int] = 1, + eta: float = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.Tensor] = None, + initial_audio_waveforms: Optional[torch.Tensor] = None, + initial_audio_sampling_rate: Optional[torch.Tensor] = None, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.LongTensor] = None, + negative_attention_mask: Optional[torch.LongTensor] = None, + return_dict: bool = True, + callback: Optional[Callable[[int, int, torch.Tensor], None]] = None, + callback_steps: Optional[int] = 1, + output_type: Optional[str] = "pt", + use_cache: Optional[bool] = False, + ): + r""" + The call function to the pipeline for generation. 
+ + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide audio generation. If not defined, you need to pass `prompt_embeds`. + audio_end_in_s (`float`, *optional*, defaults to 47.55): + Audio end index in seconds. + audio_start_in_s (`float`, *optional*, defaults to 0): + Audio start index in seconds. + num_inference_steps (`int`, *optional*, defaults to 100): + The number of denoising steps. More denoising steps usually lead to a higher quality audio at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 7.0): + A higher guidance scale value encourages the model to generate audio that is closely linked to the text + `prompt` at the expense of lower sound quality. Guidance scale is enabled when `guidance_scale > 1`. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide what to not include in audio generation. If not defined, you need to + pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`). + num_waveforms_per_prompt (`int`, *optional*, defaults to 1): + The number of waveforms to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies + to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make + generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for audio + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor is generated by sampling using the supplied random `generator`. + initial_audio_waveforms (`torch.Tensor`, *optional*): + Optional initial audio waveforms to use as the initial audio waveform for generation. Must be of shape + `(batch_size, num_channels, audio_length)` or `(batch_size, audio_length)`, where `batch_size` + corresponds to the number of prompts passed to the model. + initial_audio_sampling_rate (`int`, *optional*): + Sampling rate of the `initial_audio_waveforms`, if they are provided. Must be the same as the model. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-computed text embeddings from the text encoder model. Can be used to easily tweak text inputs, + *e.g.* prompt weighting. If not provided, text embeddings will be computed from `prompt` input + argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-computed negative text embeddings from the text encoder model. Can be used to easily tweak text + inputs, *e.g.* prompt weighting. If not provided, negative_prompt_embeds will be computed from + `negative_prompt` input argument. + attention_mask (`torch.LongTensor`, *optional*): + Pre-computed attention mask to be applied to the `prompt_embeds`. If not provided, attention mask will + be computed from `prompt` input argument. + negative_attention_mask (`torch.LongTensor`, *optional*): + Pre-computed attention mask to be applied to the `negative_text_audio_duration_embeds`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + callback (`Callable`, *optional*): + A function that calls every `callback_steps` steps during inference. 
The function is called with the + following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function is called. If not specified, the callback is called at + every step. + output_type (`str`, *optional*, defaults to `"pt"`): + The output format of the generated audio. Choose between `"np"` to return a NumPy `np.ndarray` or + `"pt"` to return a PyTorch `torch.Tensor` object. Set to `"latent"` to return the latent diffusion + model (LDM) output. + + Examples: + + Returns: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned, + otherwise a `tuple` is returned where the first element is a list with the generated audio. + """ + # 0. Convert audio input length from seconds to latent length + downsample_ratio = self.vae.hop_length + + max_audio_length_in_s = self.transformer.config.sample_size * downsample_ratio / self.vae.config.sampling_rate + if audio_end_in_s is None: + audio_end_in_s = max_audio_length_in_s + + if audio_end_in_s - audio_start_in_s > max_audio_length_in_s: + raise ValueError( + f"The total audio length requested ({audio_end_in_s-audio_start_in_s}s) is longer than the model maximum possible length ({max_audio_length_in_s}). Make sure that 'audio_end_in_s-audio_start_in_s<={max_audio_length_in_s}'." + ) + + waveform_start = int(audio_start_in_s * self.vae.config.sampling_rate) + waveform_end = int(audio_end_in_s * self.vae.config.sampling_rate) + waveform_length = int(self.transformer.config.sample_size) + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + audio_start_in_s, + audio_end_in_s, + callback_steps, + negative_prompt, + prompt_embeds, + negative_prompt_embeds, + attention_mask, + negative_attention_mask, + initial_audio_waveforms, + initial_audio_sampling_rate, + ) + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + # 3. 
Encode input prompt + prompt_embeds = self.encode_prompt( + prompt, + device, + do_classifier_free_guidance, + negative_prompt, + prompt_embeds, + negative_prompt_embeds, + attention_mask, + negative_attention_mask, + ) + + # Encode duration + seconds_start_hidden_states, seconds_end_hidden_states = self.encode_duration( + audio_start_in_s, + audio_end_in_s, + device, + do_classifier_free_guidance and (negative_prompt is not None or negative_prompt_embeds is not None), + batch_size, + ) + + # Create text_audio_duration_embeds and audio_duration_embeds + text_audio_duration_embeds = torch.cat( + [prompt_embeds, seconds_start_hidden_states, seconds_end_hidden_states], dim=1 + ) + + audio_duration_embeds = torch.cat([seconds_start_hidden_states, seconds_end_hidden_states], dim=2) + + # In case of classifier free guidance without negative prompt, we need to create unconditional embeddings and + # to concatenate it to the embeddings + if do_classifier_free_guidance and negative_prompt_embeds is None and negative_prompt is None: + negative_text_audio_duration_embeds = torch.zeros_like( + text_audio_duration_embeds, device=text_audio_duration_embeds.device + ) + text_audio_duration_embeds = torch.cat( + [negative_text_audio_duration_embeds, text_audio_duration_embeds], dim=0 + ) + audio_duration_embeds = torch.cat([audio_duration_embeds, audio_duration_embeds], dim=0) + + bs_embed, seq_len, hidden_size = text_audio_duration_embeds.shape + # duplicate audio_duration_embeds and text_audio_duration_embeds for each generation per prompt, using mps friendly method + text_audio_duration_embeds = text_audio_duration_embeds.repeat(1, num_waveforms_per_prompt, 1) + text_audio_duration_embeds = text_audio_duration_embeds.view( + bs_embed * num_waveforms_per_prompt, seq_len, hidden_size + ) + + audio_duration_embeds = audio_duration_embeds.repeat(1, num_waveforms_per_prompt, 1) + audio_duration_embeds = audio_duration_embeds.view( + bs_embed * num_waveforms_per_prompt, -1, audio_duration_embeds.shape[-1] + ) + + # 4. Prepare timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + + # 5. Prepare latent variables + num_channels_vae = self.transformer.config.in_channels + latents = self.prepare_latents( + batch_size * num_waveforms_per_prompt, + num_channels_vae, + waveform_length, + text_audio_duration_embeds.dtype, + device, + generator, + latents, + initial_audio_waveforms, + num_waveforms_per_prompt, + audio_channels=self.vae.config.audio_channels, + ) + + # 6. Prepare extra step kwargs + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 7. Prepare rotary positional embedding + rotary_embedding = get_1d_rotary_pos_embed( + self.rotary_embed_dim, + latents.shape[2] + audio_duration_embeds.shape[1], + use_real=True, + repeat_interleave_real=False, + ) + + cos = rotary_embedding[0][None, :, None, :].to(latents.device).to(torch.float16) + sin = rotary_embedding[1][None, :, None, :].to(latents.device).to(torch.float16) + rotary_embedding = (cos, sin) + + # 8. 
Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # predict the noise residual + noise_pred = self.transformer( + i, + latent_model_input, + t.unsqueeze(0), + encoder_hidden_states=text_audio_duration_embeds, + global_hidden_states=audio_duration_embeds, + rotary_embedding=rotary_embedding, + return_dict=False, + use_cache=use_cache, + )[0] + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + if callback is not None and i % callback_steps == 0: + step_idx = i // getattr(self.scheduler, "order", 1) + callback(step_idx, t, latents) + + # 9. Post-processing + if not output_type == "latent": + audio = self.vae.decode(latents).sample + else: + return AudioPipelineOutput(audios=latents) + + audio = audio[:, :, waveform_start:waveform_end] + + if output_type == "np": + audio = audio.cpu().float().numpy() + + self.maybe_free_model_hooks() + + if not return_dict: + return (audio,) + + return AudioPipelineOutput(audios=audio) diff --git a/MindIE/MindIE-Torch/built-in/foundation/stable_audio_open_1.0/stable_audio_open/stableaudio/schedulers/__init__.py b/MindIE/MindIE-Torch/built-in/foundation/stable_audio_open_1.0/stable_audio_open/stableaudio/schedulers/__init__.py new file mode 100644 index 0000000000..5bad3d9b15 --- /dev/null +++ b/MindIE/MindIE-Torch/built-in/foundation/stable_audio_open_1.0/stable_audio_open/stableaudio/schedulers/__init__.py @@ -0,0 +1 @@ +from .scheduling_cosine_dpmsolver_multistep import CosineDPMSolverMultistepScheduler \ No newline at end of file diff --git a/MindIE/MindIE-Torch/built-in/foundation/stable_audio_open_1.0/stable_audio_open/stableaudio/schedulers/scheduling_cosine_dpmsolver_multistep.py b/MindIE/MindIE-Torch/built-in/foundation/stable_audio_open_1.0/stable_audio_open/stableaudio/schedulers/scheduling_cosine_dpmsolver_multistep.py new file mode 100644 index 0000000000..7a33eae2b9 --- /dev/null +++ b/MindIE/MindIE-Torch/built-in/foundation/stable_audio_open_1.0/stable_audio_open/stableaudio/schedulers/scheduling_cosine_dpmsolver_multistep.py @@ -0,0 +1,572 @@ +# Copyright 2024 TSAIL Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +# DISCLAIMER: This file is strongly influenced by https://github.com/LuChengTHU/dpm-solver and https://github.com/NVlabs/edm + +import math +from typing import List, Optional, Tuple, Union + +import numpy as np +import torch + +from diffusers.configuration_utils import ConfigMixin, register_to_config +from diffusers.schedulers.scheduling_utils import SchedulerMixin, SchedulerOutput +from .scheduling_dpmsolver_sde import BrownianTreeNoiseSampler + + +class CosineDPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin): + """ + Implements a variant of `DPMSolverMultistepScheduler` with cosine schedule, proposed by Nichol and Dhariwal (2021). + This scheduler was used in Stable Audio Open [1]. + + [1] Evans, Parker, et al. "Stable Audio Open" https://arxiv.org/abs/2407.14358 + + This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic + methods the library implements for all schedulers such as loading and saving. + + Args: + sigma_min (`float`, *optional*, defaults to 0.3): + Minimum noise magnitude in the sigma schedule. This was set to 0.3 in Stable Audio Open [1]. + sigma_max (`float`, *optional*, defaults to 500): + Maximum noise magnitude in the sigma schedule. This was set to 500 in Stable Audio Open [1]. + sigma_data (`float`, *optional*, defaults to 1.0): + The standard deviation of the data distribution. This is set to 1.0 in Stable Audio Open [1]. + sigma_schedule (`str`, *optional*, defaults to `exponential`): + Sigma schedule to compute the `sigmas`. By default, we the schedule introduced in the EDM paper + (https://arxiv.org/abs/2206.00364). Other acceptable value is "exponential". The exponential schedule was + incorporated in this model: https://huggingface.co/stabilityai/cosxl. + num_train_timesteps (`int`, defaults to 1000): + The number of diffusion steps to train the model. + solver_order (`int`, defaults to 2): + The DPMSolver order which can be `1` or `2`. It is recommended to use `solver_order=2`. + prediction_type (`str`, defaults to `v_prediction`, *optional*): + Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion process), + `sample` (directly predicts the noisy sample`) or `v_prediction` (see section 2.4 of [Imagen + Video](https://imagen.research.google/video/paper.pdf) paper). + solver_type (`str`, defaults to `midpoint`): + Solver type for the second-order solver; can be `midpoint` or `heun`. The solver type slightly affects the + sample quality, especially for a small number of steps. It is recommended to use `midpoint` solvers. + lower_order_final (`bool`, defaults to `True`): + Whether to use lower-order solvers in the final steps. Only valid for < 15 inference steps. This can + stabilize the sampling of DPMSolver for steps < 15, especially for steps <= 10. + euler_at_final (`bool`, defaults to `False`): + Whether to use Euler's method in the final step. It is a trade-off between numerical stability and detail + richness. This can stabilize the sampling of the SDE variant of DPMSolver for small number of inference + steps, but sometimes may result in blurring. + final_sigmas_type (`str`, defaults to `"zero"`): + The final `sigma` value for the noise schedule during the sampling process. If `"sigma_min"`, the final + sigma is the same as the last sigma in the training schedule. If `zero`, the final sigma is set to 0. 
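+
+    Example:
+        A minimal denoising-loop sketch for orientation only; the zero-returning `denoiser` below is just a
+        stand-in for the Stable Audio diffusion transformer, and the latent shape is illustrative:
+
+        ```py
+        >>> import torch
+        >>> scheduler = CosineDPMSolverMultistepScheduler()
+        >>> scheduler.set_timesteps(num_inference_steps=100, device="cpu")
+        >>> denoiser = lambda x, t: torch.zeros_like(x)  # placeholder for the real model forward pass
+        >>> sample = torch.randn(1, 64, 1024) * scheduler.init_noise_sigma  # illustrative latent shape
+        >>> for t in scheduler.timesteps:
+        ...     model_input = scheduler.scale_model_input(sample, t)
+        ...     noise_pred = denoiser(model_input, t)
+        ...     sample = scheduler.step(noise_pred, t, sample).prev_sample
+        ```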
+ """ + + _compatibles = [] + order = 1 + + @register_to_config + def __init__( + self, + sigma_min: float = 0.3, + sigma_max: float = 500, + sigma_data: float = 1.0, + sigma_schedule: str = "exponential", + num_train_timesteps: int = 1000, + solver_order: int = 2, + prediction_type: str = "v_prediction", + rho: float = 7.0, + solver_type: str = "midpoint", + lower_order_final: bool = True, + euler_at_final: bool = False, + final_sigmas_type: Optional[str] = "zero", # "zero", "sigma_min" + ): + if solver_type not in ["midpoint", "heun"]: + if solver_type in ["logrho", "bh1", "bh2"]: + self.register_to_config(solver_type="midpoint") + else: + raise NotImplementedError(f"{solver_type} is not implemented for {self.__class__}") + + ramp = torch.linspace(0, 1, num_train_timesteps) + if sigma_schedule == "karras": + sigmas = self._compute_karras_sigmas(ramp) + elif sigma_schedule == "exponential": + sigmas = self._compute_exponential_sigmas(ramp) + + self.timesteps = self.precondition_noise(sigmas) + + self.sigmas = torch.cat([sigmas, torch.zeros(1, device=sigmas.device)]) + + # setable values + self.num_inference_steps = None + self.model_outputs = [None] * solver_order + self.lower_order_nums = 0 + self._step_index = None + self._begin_index = None + self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication + + @property + def init_noise_sigma(self): + # standard deviation of the initial noise distribution + return (self.config.sigma_max**2 + 1) ** 0.5 + + @property + def step_index(self): + """ + The index counter for current timestep. It will increase 1 after each scheduler step. + """ + return self._step_index + + @property + def begin_index(self): + """ + The index for the first timestep. It should be set from pipeline with `set_begin_index` method. + """ + return self._begin_index + + # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index + def set_begin_index(self, begin_index: int = 0): + """ + Sets the begin index for the scheduler. This function should be run from pipeline before the inference. + + Args: + begin_index (`int`): + The begin index for the scheduler. 
+ """ + self._begin_index = begin_index + + # Copied from diffusers.schedulers.scheduling_edm_euler.EDMEulerScheduler.precondition_inputs + def precondition_inputs(self, sample, sigma): + c_in = 1 / ((sigma**2 + self.config.sigma_data**2) ** 0.5) + scaled_sample = sample * c_in + return scaled_sample + + def precondition_noise(self, sigma): + if not isinstance(sigma, torch.Tensor): + sigma = torch.tensor([sigma]) + + return sigma.atan() / math.pi * 2 + + # Copied from diffusers.schedulers.scheduling_edm_euler.EDMEulerScheduler.precondition_outputs + def precondition_outputs(self, sample, model_output, sigma): + sigma_data = self.config.sigma_data + c_skip = sigma_data**2 / (sigma**2 + sigma_data**2) + + if self.config.prediction_type == "epsilon": + c_out = sigma * sigma_data / (sigma**2 + sigma_data**2) ** 0.5 + elif self.config.prediction_type == "v_prediction": + c_out = -sigma * sigma_data / (sigma**2 + sigma_data**2) ** 0.5 + else: + raise ValueError(f"Prediction type {self.config.prediction_type} is not supported.") + + denoised = c_skip * sample + c_out * model_output + + return denoised + + # Copied from diffusers.schedulers.scheduling_edm_euler.EDMEulerScheduler.scale_model_input + def scale_model_input(self, sample: torch.Tensor, timestep: Union[float, torch.Tensor]) -> torch.Tensor: + """ + Ensures interchangeability with schedulers that need to scale the denoising model input depending on the + current timestep. Scales the denoising model input by `(sigma**2 + 1) ** 0.5` to match the Euler algorithm. + + Args: + sample (`torch.Tensor`): + The input sample. + timestep (`int`, *optional*): + The current timestep in the diffusion chain. + + Returns: + `torch.Tensor`: + A scaled input sample. + """ + if self.step_index is None: + self._init_step_index(timestep) + + sigma = self.sigmas[self.step_index] + sample = self.precondition_inputs(sample, sigma) + + self.is_scale_input_called = True + return sample + + def set_timesteps(self, num_inference_steps: int = None, device: Union[str, torch.device] = None): + """ + Sets the discrete timesteps used for the diffusion chain (to be run before inference). + + Args: + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. 
+ """ + + self.num_inference_steps = num_inference_steps + + ramp = torch.linspace(0, 1, self.num_inference_steps) + if self.config.sigma_schedule == "karras": + sigmas = self._compute_karras_sigmas(ramp) + elif self.config.sigma_schedule == "exponential": + sigmas = self._compute_exponential_sigmas(ramp) + + sigmas = sigmas.to(dtype=torch.float32, device=device) + self.timesteps = self.precondition_noise(sigmas) + + if self.config.final_sigmas_type == "sigma_min": + sigma_last = self.config.sigma_min + elif self.config.final_sigmas_type == "zero": + sigma_last = 0 + else: + raise ValueError( + f"`final_sigmas_type` must be one of 'zero', or 'sigma_min', but got {self.config.final_sigmas_type}" + ) + + self.sigmas = torch.cat([sigmas, torch.tensor([sigma_last], dtype=torch.float32, device=device)]) + + self.model_outputs = [ + None, + ] * self.config.solver_order + self.lower_order_nums = 0 + + # add an index counter for schedulers that allow duplicated timesteps + self._step_index = None + self._begin_index = None + self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication + + # if a noise sampler is used, reinitialise it + self.noise_sampler = None + + # Copied from diffusers.schedulers.scheduling_edm_euler.EDMEulerScheduler._compute_karras_sigmas + def _compute_karras_sigmas(self, ramp, sigma_min=None, sigma_max=None) -> torch.Tensor: + """Constructs the noise schedule of Karras et al. (2022).""" + sigma_min = sigma_min or self.config.sigma_min + sigma_max = sigma_max or self.config.sigma_max + + rho = self.config.rho + min_inv_rho = sigma_min ** (1 / rho) + max_inv_rho = sigma_max ** (1 / rho) + sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho + return sigmas + + # Copied from diffusers.schedulers.scheduling_edm_euler.EDMEulerScheduler._compute_exponential_sigmas + def _compute_exponential_sigmas(self, ramp, sigma_min=None, sigma_max=None) -> torch.Tensor: + """Implementation closely follows k-diffusion. + + https://github.com/crowsonkb/k-diffusion/blob/6ab5146d4a5ef63901326489f31f1d8e7dd36b48/k_diffusion/sampling.py#L26 + """ + sigma_min = sigma_min or self.config.sigma_min + sigma_max = sigma_max or self.config.sigma_max + sigmas = torch.linspace(math.log(sigma_min), math.log(sigma_max), len(ramp)).exp().flip(0) + return sigmas + + # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._sigma_to_t + def _sigma_to_t(self, sigma, log_sigmas): + # get log sigma + log_sigma = np.log(np.maximum(sigma, 1e-10)) + + # get distribution + dists = log_sigma - log_sigmas[:, np.newaxis] + + # get sigmas range + low_idx = np.cumsum((dists >= 0), axis=0).argmax(axis=0).clip(max=log_sigmas.shape[0] - 2) + high_idx = low_idx + 1 + + low = log_sigmas[low_idx] + high = log_sigmas[high_idx] + + # interpolate sigmas + w = (low - log_sigma) / (low - high) + w = np.clip(w, 0, 1) + + # transform interpolation to time range + t = (1 - w) * low_idx + w * high_idx + t = t.reshape(sigma.shape) + return t + + def _sigma_to_alpha_sigma_t(self, sigma): + alpha_t = torch.tensor(1) # Inputs are pre-scaled before going into unet, so alpha_t = 1 + sigma_t = sigma + + return alpha_t, sigma_t + + def convert_model_output( + self, + model_output: torch.Tensor, + sample: torch.Tensor = None, + ) -> torch.Tensor: + """ + Convert the model output to the corresponding type the DPMSolver/DPMSolver++ algorithm needs. 
DPM-Solver is + designed to discretize an integral of the noise prediction model, and DPM-Solver++ is designed to discretize an + integral of the data prediction model. + + + + The algorithm and model type are decoupled. You can use either DPMSolver or DPMSolver++ for both noise + prediction and data prediction models. + + + + Args: + model_output (`torch.Tensor`): + The direct output from the learned diffusion model. + sample (`torch.Tensor`): + A current instance of a sample created by the diffusion process. + + Returns: + `torch.Tensor`: + The converted model output. + """ + sigma = self.sigmas[self.step_index] + x0_pred = self.precondition_outputs(sample, model_output, sigma) + + return x0_pred + + def dpm_solver_first_order_update( + self, + model_output: torch.Tensor, + sample: torch.Tensor = None, + noise: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + """ + One step for the first-order DPMSolver (equivalent to DDIM). + + Args: + model_output (`torch.Tensor`): + The direct output from the learned diffusion model. + sample (`torch.Tensor`): + A current instance of a sample created by the diffusion process. + + Returns: + `torch.Tensor`: + The sample tensor at the previous timestep. + """ + sigma_t, sigma_s = self.sigmas[self.step_index + 1], self.sigmas[self.step_index] + alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t) + alpha_s, sigma_s = self._sigma_to_alpha_sigma_t(sigma_s) + lambda_t = torch.log(alpha_t) - torch.log(sigma_t) + lambda_s = torch.log(alpha_s) - torch.log(sigma_s) + + h = lambda_t - lambda_s + assert noise is not None + x_t = ( + (sigma_t / sigma_s * torch.exp(-h)) * sample + + (alpha_t * (1 - torch.exp(-2.0 * h))) * model_output + + sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)) * noise + ) + + return x_t + + def multistep_dpm_solver_second_order_update( + self, + model_output_list: List[torch.Tensor], + sample: torch.Tensor = None, + noise: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + """ + One step for the second-order multistep DPMSolver. + + Args: + model_output_list (`List[torch.Tensor]`): + The direct outputs from learned diffusion model at current and latter timesteps. + sample (`torch.Tensor`): + A current instance of a sample created by the diffusion process. + + Returns: + `torch.Tensor`: + The sample tensor at the previous timestep. 
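+
+        Note:
+            With the default `midpoint` solver type this evaluates the sde-dpmsolver++ update
+            `x_t = (sigma_t / sigma_s0) * exp(-h) * sample + alpha_t * (1 - exp(-2h)) * (D0 + 0.5 * D1)
+            + sigma_t * sqrt(1 - exp(-2h)) * noise`, where `h = lambda_t - lambda_s0`, `D0 = m0` and
+            `D1 = (m0 - m1) / r0`, matching the implementation below.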
+ """ + sigma_t, sigma_s0, sigma_s1 = ( + self.sigmas[self.step_index + 1], + self.sigmas[self.step_index], + self.sigmas[self.step_index - 1], + ) + + alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t) + alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0) + alpha_s1, sigma_s1 = self._sigma_to_alpha_sigma_t(sigma_s1) + + lambda_t = torch.log(alpha_t) - torch.log(sigma_t) + lambda_s0 = torch.log(alpha_s0) - torch.log(sigma_s0) + lambda_s1 = torch.log(alpha_s1) - torch.log(sigma_s1) + + m0, m1 = model_output_list[-1], model_output_list[-2] + + h, h_0 = lambda_t - lambda_s0, lambda_s0 - lambda_s1 + r0 = h_0 / h + D0, D1 = m0, (1.0 / r0) * (m0 - m1) + + # sde-dpmsolver++ + assert noise is not None + if self.config.solver_type == "midpoint": + x_t = ( + (sigma_t / sigma_s0 * torch.exp(-h)) * sample + + (alpha_t * (1 - torch.exp(-2.0 * h))) * D0 + + 0.5 * (alpha_t * (1 - torch.exp(-2.0 * h))) * D1 + + sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)) * noise + ) + elif self.config.solver_type == "heun": + x_t = ( + (sigma_t / sigma_s0 * torch.exp(-h)) * sample + + (alpha_t * (1 - torch.exp(-2.0 * h))) * D0 + + (alpha_t * ((1.0 - torch.exp(-2.0 * h)) / (-2.0 * h) + 1.0)) * D1 + + sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)) * noise + ) + + return x_t + + # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.index_for_timestep + def index_for_timestep(self, timestep, schedule_timesteps=None): + if schedule_timesteps is None: + schedule_timesteps = self.timesteps + + index_candidates = (schedule_timesteps == timestep).nonzero() + + if len(index_candidates) == 0: + step_index = len(self.timesteps) - 1 + # The sigma index that is taken for the **very** first `step` + # is always the second index (or the last index if there is only 1) + # This way we can ensure we don't accidentally skip a sigma in + # case we start in the middle of the denoising schedule (e.g. for image-to-image) + elif len(index_candidates) > 1: + step_index = index_candidates[1].item() + else: + step_index = index_candidates[0].item() + + return step_index + + # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler._init_step_index + def _init_step_index(self, timestep): + """ + Initialize the step_index counter for the scheduler. + """ + + if self.begin_index is None: + if isinstance(timestep, torch.Tensor): + timestep = timestep.to(self.timesteps.device) + self._step_index = self.index_for_timestep(timestep) + else: + self._step_index = self._begin_index + + def step( + self, + model_output: torch.Tensor, + timestep: Union[int, torch.Tensor], + sample: torch.Tensor, + generator=None, + return_dict: bool = True, + ) -> Union[SchedulerOutput, Tuple]: + """ + Predict the sample from the previous timestep by reversing the SDE. This function propagates the sample with + the multistep DPMSolver. + + Args: + model_output (`torch.Tensor`): + The direct output from learned diffusion model. + timestep (`int`): + The current discrete timestep in the diffusion chain. + sample (`torch.Tensor`): + A current instance of a sample created by the diffusion process. + generator (`torch.Generator`, *optional*): + A random number generator. + return_dict (`bool`): + Whether or not to return a [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`. 
+ + Returns: + [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`: + If return_dict is `True`, [`~schedulers.scheduling_utils.SchedulerOutput`] is returned, otherwise a + tuple is returned where the first element is the sample tensor. + + """ + if self.num_inference_steps is None: + raise ValueError( + "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler" + ) + + if self.step_index is None: + self._init_step_index(timestep) + + # Improve numerical stability for small number of steps + lower_order_final = (self.step_index == len(self.timesteps) - 1) and ( + self.config.euler_at_final + or (self.config.lower_order_final and len(self.timesteps) < 15) + or self.config.final_sigmas_type == "zero" + ) + lower_order_second = ( + (self.step_index == len(self.timesteps) - 2) and self.config.lower_order_final and len(self.timesteps) < 15 + ) + + model_output = self.convert_model_output(model_output, sample=sample) + for i in range(self.config.solver_order - 1): + self.model_outputs[i] = self.model_outputs[i + 1] + self.model_outputs[-1] = model_output + + if self.noise_sampler is None: + seed = None + if generator is not None: + seed = ( + [g.initial_seed() for g in generator] if isinstance(generator, list) else generator.initial_seed() + ) + self.noise_sampler = BrownianTreeNoiseSampler( + model_output, sigma_min=self.config.sigma_min, sigma_max=self.config.sigma_max, seed=seed + ) + noise = self.noise_sampler(self.sigmas[self.step_index], self.sigmas[self.step_index + 1]).to( + model_output.device + ) + + if self.config.solver_order == 1 or self.lower_order_nums < 1 or lower_order_final: + prev_sample = self.dpm_solver_first_order_update(model_output, sample=sample, noise=noise) + elif self.config.solver_order == 2 or self.lower_order_nums < 2 or lower_order_second: + prev_sample = self.multistep_dpm_solver_second_order_update(self.model_outputs, sample=sample, noise=noise) + + if self.lower_order_nums < self.config.solver_order: + self.lower_order_nums += 1 + + # upon completion increase step index by one + self._step_index += 1 + + if not return_dict: + return (prev_sample,) + + return SchedulerOutput(prev_sample=prev_sample) + + # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.add_noise + def add_noise( + self, + original_samples: torch.Tensor, + noise: torch.Tensor, + timesteps: torch.Tensor, + ) -> torch.Tensor: + # Make sure sigmas and timesteps have the same device and dtype as original_samples + sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype) + if original_samples.device.type == "mps" and torch.is_floating_point(timesteps): + # mps does not support float64 + schedule_timesteps = self.timesteps.to(original_samples.device, dtype=torch.float32) + timesteps = timesteps.to(original_samples.device, dtype=torch.float32) + else: + schedule_timesteps = self.timesteps.to(original_samples.device) + timesteps = timesteps.to(original_samples.device) + + # self.begin_index is None when scheduler is used for training, or pipeline does not implement set_begin_index + if self.begin_index is None: + step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timesteps] + elif self.step_index is not None: + # add_noise is called after first denoising step (for inpainting) + step_indices = [self.step_index] * timesteps.shape[0] + else: + # add noise is called before first denoising step to create initial latent(img2img) + step_indices = [self.begin_index] * 
timesteps.shape[0] + + sigma = sigmas[step_indices].flatten() + while len(sigma.shape) < len(original_samples.shape): + sigma = sigma.unsqueeze(-1) + + noisy_samples = original_samples + noise * sigma + return noisy_samples + + def __len__(self): + return self.config.num_train_timesteps diff --git a/MindIE/MindIE-Torch/built-in/foundation/stable_audio_open_1.0/stable_audio_open/stableaudio/schedulers/scheduling_dpmsolver_sde.py b/MindIE/MindIE-Torch/built-in/foundation/stable_audio_open_1.0/stable_audio_open/stableaudio/schedulers/scheduling_dpmsolver_sde.py new file mode 100644 index 0000000000..2193d4175d --- /dev/null +++ b/MindIE/MindIE-Torch/built-in/foundation/stable_audio_open_1.0/stable_audio_open/stableaudio/schedulers/scheduling_dpmsolver_sde.py @@ -0,0 +1,70 @@ +# Copyright 2024 Katherine Crowson, The HuggingFace Team and hlky. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import torch + +from .scheduling_utils import BrownianTree + + +class BatchedBrownianTree: + """A wrapper around torchsde.BrownianTree that enables batches of entropy.""" + + def __init__(self, x, t0, t1, seed=None, **kwargs): + t0, t1, self.sign = self.sort(t0, t1) + w0 = kwargs.get("w0", torch.zeros_like(x)) + if seed is None: + seed = torch.randint(0, 2**63 - 1, []).item() + self.batched = True + try: + assert len(seed) == x.shape[0] + w0 = w0[0] + except TypeError: + seed = [seed] + self.batched = False + self.trees = [BrownianTree(t0, w0, t1, entropy=s, **kwargs) for s in seed] + + @staticmethod + def sort(a, b): + return (a, b, 1) if a < b else (b, a, -1) + + def __call__(self, t0, t1): + t0, t1, sign = self.sort(t0, t1) + w = torch.stack([tree(t0, t1) for tree in self.trees]) * (self.sign * sign) + return w if self.batched else w[0] + + +class BrownianTreeNoiseSampler: + """A noise sampler backed by a torchsde.BrownianTree. + + Args: + x (Tensor): The tensor whose shape, device and dtype to use to generate + random samples. + sigma_min (float): The low end of the valid interval. + sigma_max (float): The high end of the valid interval. + seed (int or List[int]): The random seed. If a list of seeds is + supplied instead of a single integer, then the noise sampler will use one BrownianTree per batch item, each + with its own seed. + transform (callable): A function that maps sigma to the sampler's + internal timestep. 
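+
+    A minimal usage sketch (the sigma values and tensor shape below are illustrative; this mirrors how
+    `CosineDPMSolverMultistepScheduler.step` drives the sampler):
+    >>> import torch
+    >>> x = torch.randn(2, 64, 1024)  # provides the shape, device and dtype of the noise samples
+    >>> sampler = BrownianTreeNoiseSampler(x, sigma_min=0.3, sigma_max=500, seed=0)
+    >>> noise = sampler(500.0, 250.0)  # correlated noise for one step from sigma=500 to sigma_next=250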
+ """ + + def __init__(self, x, sigma_min, sigma_max, seed=None, transform=lambda x: x): + self.transform = transform + t0, t1 = self.transform(torch.as_tensor(sigma_min)), self.transform(torch.as_tensor(sigma_max)) + self.tree = BatchedBrownianTree(x, t0, t1, seed) + + def __call__(self, sigma, sigma_next): + t0, t1 = self.transform(torch.as_tensor(sigma)), self.transform(torch.as_tensor(sigma_next)) + return self.tree(t0, t1) / (t1 - t0).abs().sqrt() \ No newline at end of file diff --git a/MindIE/MindIE-Torch/built-in/foundation/stable_audio_open_1.0/stable_audio_open/stableaudio/schedulers/scheduling_utils.py b/MindIE/MindIE-Torch/built-in/foundation/stable_audio_open_1.0/stable_audio_open/stableaudio/schedulers/scheduling_utils.py new file mode 100644 index 0000000000..9bc31717cf --- /dev/null +++ b/MindIE/MindIE-Torch/built-in/foundation/stable_audio_open_1.0/stable_audio_open/stableaudio/schedulers/scheduling_utils.py @@ -0,0 +1,888 @@ + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +import trampoline +import warnings + +import numpy as np +import torch + +from torchsde._brownian import brownian_base +from torchsde.settings import LEVY_AREA_APPROXIMATIONS +from torchsde.types import Scalar, Optional, Tuple, Union, Tensor + +_rsqrt3 = 1 / math.sqrt(3) +_r12 = 1 / 12 + + +def _randn(size, dtype, device, seed): + generator = torch.Generator(device).manual_seed(int(seed)) + return torch.randn(size, dtype=dtype, device=device, generator=generator) + + +def _is_scalar(x): + return isinstance(x, int) or isinstance(x, float) or (isinstance(x, torch.Tensor) and x.numel() == 1) + + +def _assert_floating_tensor(name, tensor): + if not torch.is_tensor(tensor): + raise ValueError(f"{name}={tensor} should be a Tensor.") + if not tensor.is_floating_point(): + raise ValueError(f"{name}={tensor} should be floating point.") + + +def _check_tensor_info(*tensors, size, dtype, device): + """Check if sizes, dtypes, and devices of input tensors all match prescribed values.""" + tensors = list(filter(torch.is_tensor, tensors)) + + if dtype is None and len(tensors) == 0: + dtype = torch.get_default_dtype() + if device is None and len(tensors) == 0: + device = torch.device("cpu") + + sizes = [] if size is None else [size] + sizes += [t.shape for t in tensors] + + dtypes = [] if dtype is None else [dtype] + dtypes += [t.dtype for t in tensors] + + devices = [] if device is None else [device] + devices += [t.device for t in tensors] + + if len(sizes) == 0: + raise ValueError("Must either specify `size` or pass in `W` or `H` to implicitly define the size.") + + if not all(i == sizes[0] for i in sizes): + raise ValueError("Multiple sizes found. Make sure `size` and `W` or `H` are consistent.") + if not all(i == dtypes[0] for i in dtypes): + raise ValueError("Multiple dtypes found. Make sure `dtype` and `W` or `H` are consistent.") + if not all(i == devices[0] for i in devices): + raise ValueError("Multiple devices found. 
Make sure `device` and `W` or `H` are consistent.") + + # Make sure size is a tuple (not a torch.Size) for neat repr-printing purposes. + return tuple(sizes[0]), dtypes[0], devices[0] + + +def _davie_foster_approximation(W, H, h, levy_area_approximation, get_noise): + if levy_area_approximation in (LEVY_AREA_APPROXIMATIONS.none, LEVY_AREA_APPROXIMATIONS.space_time): + return None + elif W.ndimension() in (0, 1): + # If we have zero or one dimensions then treat the scalar / single dimension we have as batch, so that the + # Brownian motion is one dimensional and the Levy area is zero. + return torch.zeros_like(W) + else: + # Davie's approximation to the Levy area from space-time Levy area + A = H.unsqueeze(-1) * W.unsqueeze(-2) - W.unsqueeze(-1) * H.unsqueeze(-2) + noise = get_noise() + noise = noise - noise.transpose(-1, -2) # noise is skew symmetric of variance 2 + if levy_area_approximation == LEVY_AREA_APPROXIMATIONS.foster: + # Foster's additional correction to Davie's approximation + tenth_h = 0.1 * h + H_squared = H ** 2 + std = (tenth_h * (tenth_h + H_squared.unsqueeze(-1) + H_squared.unsqueeze(-2))).sqrt() + else: # davie approximation + std = math.sqrt(_r12 * h ** 2) + a_tilde = std * noise + A += a_tilde + return A + + +def _H_to_U(W: torch.Tensor, H: torch.Tensor, h: float) -> torch.Tensor: + return h * (.5 * W + H) + + +class _EmptyDict: + def __setitem__(self, key, value): + pass + + def __getitem__(self, item): + raise KeyError + + +class _LRUDict(dict): + def __init__(self, max_size): + super().__init__() + self._max_size = max_size + self._keys = [] + + def __setitem__(self, key, value): + if key in self: + self._keys.remove(key) + elif len(self) >= self._max_size: + del self[self._keys.pop(0)] + super().__setitem__(key, value) + self._keys.append(key) + + +class _Interval: + # Intervals correspond to some subinterval of the overall interval [t0, t1]. + # They are arranged as a binary tree: each node corresponds to an interval. If a node has children, they are left + # and right subintervals, which partition the parent interval. + + __slots__ = ( + # These are the things that every interval has + '_start', + '_end', + '_parent', + '_is_left', + '_top', + # These are the things that intervals which are parents also have + '_midway', + '_spawn_key', + '_depth', + '_W_seed', + '_H_seed', + '_left_a_seed', + '_right_a_seed', + '_left_child', + '_right_child') + + def __init__(self, start, end, parent, is_left, top): + self._start = top._round(start) # the left hand edge of the interval + self._end = top._round(end) # the right hand edge of the interval + self._parent = parent # our parent interval + self._is_left = is_left # are we the left or right child of our parent + self._top = top # the top-level BrownianInterval, where we cache certain state + self._midway = None # The point at which we split between left and right subintervals + + ######################################## + # Calculate increments and levy area # + ######################################## + # + # This is a little bit convoluted, so here's an explanation. + # + # The entry point is _increment_and_levy_area, below. This immediately calls _increment_and_space_time_levy_area, + # applies the space-time to full Levy area correction, and then returns. 
+ # + # _increment_and_space_time_levy_area in turn calls a central LRU cache, as (later on) we'll need the increment and + # space-time Levy area of the parent interval to compute our own increment and space-time Levy area, and it's likely + # that our parent exists in the cache, as if we're being queried then our parent was probably queried recently as + # well. + # (The top-level BrownianInterval overrides _increment_and_space_time_levy_area to return its own increment and + # space-time Levy area, effectively holding them permanently in the cache.) + # + # If the request isn't found in the LRU cache then it computes it from its parent. + # Now it turns out that the size of our increment and space-time Levy area is really most naturally thought of as a + # property of our parent: it depends on our parent's increment, space-time Levy area, and whether we are the left or + # right interval within our parent. So _increment_and_space_time_levy_area in turn checks if we are on the + # left or right of our parent and does most of the computation using the parent's attributes. + + def _increment_and_levy_area(self): + W, H = trampoline.trampoline(self._increment_and_space_time_levy_area()) + A = _davie_foster_approximation(W, H, self._end - self._start, self._top._levy_area_approximation, + self._randn_levy) + return W, H, A + + def _increment_and_space_time_levy_area(self): + try: + return self._top._increment_and_space_time_levy_area_cache[self] + except KeyError: + parent = self._parent + + W, H = yield parent._increment_and_space_time_levy_area() + h_reciprocal = 1 / (parent._end - parent._start) + left_diff = parent._midway - parent._start + right_diff = parent._end - parent._midway + + if self._top._have_H: + left_diff_squared = left_diff ** 2 + right_diff_squared = right_diff ** 2 + left_diff_cubed = left_diff * left_diff_squared + right_diff_cubed = right_diff * right_diff_squared + + v = 0.5 * math.sqrt(left_diff * right_diff / (left_diff_cubed + right_diff_cubed)) + + a = v * left_diff_squared * h_reciprocal + b = v * right_diff_squared * h_reciprocal + c = v * _rsqrt3 + + X1 = parent._randn(parent._W_seed) + X2 = parent._randn(parent._H_seed) + + third_coeff = 2 * (a * left_diff + b * right_diff) * h_reciprocal + + if self._is_left: + first_coeff = left_diff * h_reciprocal + second_coeff = 6 * first_coeff * right_diff * h_reciprocal + out_W = first_coeff * W + second_coeff * H + third_coeff * X1 + out_H = first_coeff ** 2 * H - a * X1 + c * right_diff * X2 + else: + first_coeff = right_diff * h_reciprocal + second_coeff = 6 * first_coeff * left_diff * h_reciprocal + out_W = first_coeff * W - second_coeff * H - third_coeff * X1 + out_H = first_coeff ** 2 * H - b * X1 - c * left_diff * X2 + else: + # Don't compute space-time Levy area unless we need to + + mean = left_diff * h_reciprocal * W + var = left_diff * right_diff * h_reciprocal + noise = parent._randn(parent._W_seed) + left_W = mean + math.sqrt(var) * noise + + if self._is_left: + out_W = left_W + else: + out_W = W - left_W + out_H = None + + self._top._increment_and_space_time_levy_area_cache[self] = (out_W, out_H) + return out_W, out_H + + def _randn(self, seed): + # We generate random noise deterministically wrt some seed; this seed is determined by the generator. + # This means that if we drop out of the cache, then we'll create the same random noise next time, as we still + # have the generator. 
+ size = self._top._size + return _randn(size, self._top._dtype, self._top._device, seed) + + def _a_seed(self): + return self._parent._left_a_seed if self._is_left else self._parent._right_a_seed + + def _randn_levy(self): + size = (*self._top._size, *self._top._size[-1:]) + return _randn(size, self._top._dtype, self._top._device, self._a_seed()) + + ######################################## + # Locate an interval in the hierarchy # + ######################################## + # + # The other important piece of this construction is a way to locate any given interval within the binary tree + # hierarchy. (This is typically the slightly slower part, actually, so if you want to speed things up then this is + # the bit to target.) + # + # loc finds the interval [ta, tb] - and creates it in the appropriate place (as a child of some larger interval) if + # it doesn't already exist. As in principle we may request an interval that covers multiple existing intervals, then + # in fact the interval [ta, tb] is returned as an ordered list of existing subintervals. + # + # It calls _loc, which operates recursively. See _loc for more details on how the search works. + + def _loc(self, ta, tb): + out = [] + ta = self._top._round(ta) + tb = self._top._round(tb) + trampoline.trampoline(self._loc_inner(ta, tb, out)) + return out + + def _loc_inner(self, ta, tb, out): + # Expect to have ta < tb + + # First, we (this interval) only have jurisdiction over [self._start, self._end]. So if we're asked for + # something outside of that then we pass the buck up to our parent, who is strictly larger. + if ta < self._start or tb > self._end: + raise trampoline.TailCall(self._parent._loc_inner(ta, tb, out)) + + # If it's us that's being asked for, then we add ourselves on to out and return. + if ta == self._start and tb == self._end: + out.append(self) + return + + # If we've got this far then we know that it's an interval that's within our jurisdiction, and that it's not us. + # So next we check if it's up to us to figure out, or up to our children. + if self._midway is None: + # It's up to us. Create subintervals (_split) if appropriate. + if ta == self._start: + self._split(tb) + raise trampoline.TailCall(self._left_child._loc_inner(ta, tb, out)) + # implies ta > self._start + self._split(ta) + # Query our (newly created) right_child: if tb == self._end then our right child will be the result, and it + # will tell us so. But if tb < self._end then our right_child will need to make another split of its own. + raise trampoline.TailCall(self._right_child._loc_inner(ta, tb, out)) + + # If we're here then we have children: self._midway is not None + if tb <= self._midway: + # Strictly our left_child's problem + raise trampoline.TailCall(self._left_child._loc_inner(ta, tb, out)) + if ta >= self._midway: + # Strictly our right_child's problem + raise trampoline.TailCall(self._right_child._loc_inner(ta, tb, out)) + # It's a problem for both of our children: the requested interval overlaps our midpoint. Call the left_child + # first (to append to out in the correct order), then call our right child. 
+ # (Implies ta < self._midway < tb) + yield self._left_child._loc_inner(ta, self._midway, out) + raise trampoline.TailCall(self._right_child._loc_inner(self._midway, tb, out)) + + def _set_spawn_key_and_depth(self): + self._spawn_key = 2 * self._parent._spawn_key + (0 if self._is_left else 1) + self._depth = self._parent._depth + 1 + + def _split(self, midway): + if self._top._halfway_tree: + self._split_exact(0.5 * (self._end + self._start)) + # self._midway is now the rounded halfway point. + if midway > self._midway: + self._right_child._split(midway) + elif midway < self._midway: + self._left_child._split(midway) + else: + self._split_exact(midway) + + def _split_exact(self, midway): # Create two children + self._midway = self._top._round(midway) + # Use splittable PRNGs to generate noise. + self._set_spawn_key_and_depth() + generator = np.random.SeedSequence(entropy=self._top._entropy, + spawn_key=(self._spawn_key, self._depth), + pool_size=self._top._pool_size) + self._W_seed, self._H_seed, self._left_a_seed, self._right_a_seed = generator.generate_state(4) + + self._left_child = _Interval(start=self._start, + end=midway, + parent=self, + is_left=True, + top=self._top) + self._right_child = _Interval(start=midway, + end=self._end, + parent=self, + is_left=False, + top=self._top) + + +class BrownianInterval(brownian_base.BaseBrownian, _Interval): + """Brownian interval with fixed entropy. + + Computes increments (and optionally Levy area). + + To use: + >>> bm = BrownianInterval(t0=0.0, t1=1.0, size=(4, 1), device='cuda') + >>> bm(0., 0.5) + tensor([[ 0.0733], + [-0.5692], + [ 0.1872], + [-0.3889]], device='cuda:0') + """ + + __slots__ = ( + # Inputs + '_size', + '_dtype', + '_device', + '_entropy', + '_levy_area_approximation', + '_dt', + '_tol', + '_pool_size', + '_cache_size', + '_halfway_tree', + # Quantisation + '_round', + # Caching, searching and computing values + '_increment_and_space_time_levy_area_cache', + '_last_interval', + '_have_H', + '_have_A', + '_w_h', + '_top_a_seed', + # Dependency tree creation + '_average_dt', + '_tree_dt', + '_num_evaluations' + ) + + def __init__(self, + t0: Optional[Scalar] = 0., + t1: Optional[Scalar] = 1., + size: Optional[Tuple[int, ...]] = None, + dtype: Optional[torch.dtype] = None, + device: Optional[Union[str, torch.device]] = None, + entropy: Optional[int] = None, + dt: Optional[Scalar] = None, + tol: Scalar = 0., + pool_size: int = 8, + cache_size: Optional[int] = 45, + halfway_tree: bool = False, + levy_area_approximation: str = LEVY_AREA_APPROXIMATIONS.none, + W: Optional[Tensor] = None, + H: Optional[Tensor] = None): + """Initialize the Brownian interval. + + Args: + t0 (float or Tensor): Initial time. + t1 (float or Tensor): Terminal time. + size (tuple of int): The shape of each Brownian sample. + If zero dimensional represents a scalar Brownian motion. + If one dimensional represents a batch of scalar Brownian motions. + If >two dimensional the last dimension represents the size of a + a multidimensional Brownian motion, and all previous dimensions + represent batch dimensions. + dtype (torch.dtype): The dtype of each Brownian sample. + Defaults to the PyTorch default. + device (str or torch.device): The device of each Brownian sample. + Defaults to the CPU. + entropy (int): Global seed, defaults to `None` for random entropy. + levy_area_approximation (str): Whether to also approximate Levy + area. Defaults to 'none'. 
Valid options are 'none', + 'space-time', 'davie' or 'foster', corresponding to different + approximation types. + This is needed for some higher-order SDE solvers. + dt (float or Tensor): The expected average step size of the SDE + solver. Set it if you know it (e.g. when using a fixed-step + solver); else it will be estimated from the first few queries. + This is used to set up the data structure such that it is + efficient to query at these intervals. + tol (float or Tensor): What tolerance to resolve the Brownian motion + to. Must be non-negative. Defaults to zero, i.e. floating point + resolution. Usually worth setting in conjunction with + `halfway_tree`, below. + pool_size (int): Size of the pooled entropy. If you care about + statistical randomness then increasing this will help (but will + slow things down). + cache_size (int): How big a cache of recent calculations to use. + (As new calculations depend on old calculations, this speeds + things up dramatically, rather than recomputing things.) + Set this to `None` to use an infinite cache, which will be fast + but memory inefficient. + halfway_tree (bool): Whether the dependency tree (the internal data + structure) should be the dyadic tree. Defaults to `False`. + Normally, the sample path is determined by both `entropy`, + _and_ the locations and order of the query points. Setting this + to `True` will make it deterministic with respect to just + `entropy`; however this is much slower. + W (Tensor): The increment of the Brownian motion over the interval + [t0, t1]. Will be generated randomly if not provided. + H (Tensor): The space-time Levy area of the Brownian motion over the + interval [t0, t1]. Will be generated randomly if not provided. + """ + + ##################################### + # Check and normalise inputs # + ##################################### + + if not _is_scalar(t0): + raise ValueError('Initial time t0 should be a float or 0-d torch.Tensor.') + if not _is_scalar(t1): + raise ValueError('Terminal time t1 should be a float or 0-d torch.Tensor.') + if dt is not None and not _is_scalar(dt): + raise ValueError('Expected average time step dt should be a float or 0-d torch.Tensor.') + + if t0 > t1: + raise ValueError(f'Initial time {t0} should be less than terminal time {t1}.') + t0 = float(t0) + t1 = float(t1) + if dt is not None: + dt = float(dt) + + if halfway_tree: + if tol <= 0.: + raise ValueError("`tol` should be positive.") + if dt is not None: + raise ValueError("`dt` is not used and should be set to `None` if `halfway_tree` is True.") + else: + if tol < 0.: + raise ValueError("`tol` should be non-negative.") + + size, dtype, device = _check_tensor_info(W, H, size=size, dtype=dtype, device=device) + + # Let numpy dictate randomness, so we have fewer seeds to set for reproducibility. 
+ if entropy is None: + entropy = np.random.randint(0, 2 ** 31 - 1) + + if levy_area_approximation not in LEVY_AREA_APPROXIMATIONS: + raise ValueError(f"`levy_area_approximation` must be one of {LEVY_AREA_APPROXIMATIONS}, but got " + f"'{levy_area_approximation}'.") + + ##################################### + # Record inputs # + ##################################### + + self._size = size + self._dtype = dtype + self._device = device + self._entropy = entropy + self._levy_area_approximation = levy_area_approximation + self._dt = dt + self._tol = tol + self._pool_size = pool_size + self._cache_size = cache_size + self._halfway_tree = halfway_tree + + ##################################### + # A miscellany of other things # + ##################################### + + # We keep a cache of recent queries, and their results. This is very important for speed, so that we don't + # recurse all the way up to the top every time we have a query. + if cache_size is None: + self._increment_and_space_time_levy_area_cache = {} + elif cache_size == 0: + self._increment_and_space_time_levy_area_cache = _EmptyDict() + else: + self._increment_and_space_time_levy_area_cache = _LRUDict(max_size=cache_size) + + # We keep track of the most recently queried interval, and start searching for the next interval from that + # element of the binary tree. This is because subsequent queries are likely to be near the most recent query. + self._last_interval = self + + # Precompute these as we don't want to spend lots of time checking strings in hot loops. + self._have_H = self._levy_area_approximation in (LEVY_AREA_APPROXIMATIONS.space_time, + LEVY_AREA_APPROXIMATIONS.davie, + LEVY_AREA_APPROXIMATIONS.foster) + self._have_A = self._levy_area_approximation in (LEVY_AREA_APPROXIMATIONS.davie, + LEVY_AREA_APPROXIMATIONS.foster) + + # If we like we can quantise what level we want to compute the Brownian motion to. + if tol == 0.: + self._round = lambda x: x + else: + ndigits = -int(math.log10(tol)) + self._round = lambda x: round(x, ndigits) + + # Initalise as _Interval. + # (Must come after _round but before _w_h) + super(BrownianInterval, self).__init__(start=t0, + end=t1, + parent=None, + is_left=None, + top=self) + + # Set the global increment and space-time Levy area + generator = np.random.SeedSequence(entropy=entropy, pool_size=pool_size) + initial_W_seed, initial_H_seed, top_a_seed = generator.generate_state(3) + if W is None: + W = self._randn(initial_W_seed) * math.sqrt(t1 - t0) + else: + _assert_floating_tensor('W', W) + if H is None: + H = self._randn(initial_H_seed) * math.sqrt((t1 - t0) / 12) + else: + _assert_floating_tensor('H', H) + self._w_h = (W, H) + self._top_a_seed = top_a_seed + + if not self._halfway_tree: + # We create a binary tree dependency between the points. If we don't do this then the forward pass is still + # efficient at O(N), but we end up with a dependency chain stretching along the interval [t0, t1], making + # the backward pass O(N^2). By setting up a dependency tree of depth relative to `dt` and `cache_size` we + # can instead make both directions O(N log N). + self._average_dt = 0 + self._tree_dt = t1 - t0 + self._num_evaluations = -100 # start off with a warmup period to get a decent estimate of the average + if dt is not None: + # Create the dependency tree based on the supplied hint `dt`. + self._create_dependency_tree(dt) + # If dt is None, then create the dependency tree based on observed statistics of query points. 
(In __call__) + + # Effectively permanently store our increment and space-time Levy area in the cache. + def _increment_and_space_time_levy_area(self): + return self._w_h + yield # make it a generator + + def _a_seed(self): + return self._top_a_seed + + def _set_spawn_key_and_depth(self): + self._spawn_key = 0 + self._depth = 0 + + def __call__(self, ta, tb=None, return_U=False, return_A=False): + if tb is None: + warnings.warn(f"{self.__class__.__name__} is optimised for interval-based queries, not point evaluation.") + ta, tb = self._start, ta + tb_name = 'ta' + else: + tb_name = 'tb' + ta = float(ta) + tb = float(tb) + if ta < self._start: + warnings.warn(f"Should have ta>=t0 but got ta={ta} and t0={self._start}.") + ta = self._start + if tb < self._start: + warnings.warn(f"Should have {tb_name}>=t0 but got {tb_name}={tb} and t0={self._start}.") + tb = self._start + if ta > self._end: + warnings.warn(f"Should have ta<=t1 but got ta={ta} and t1={self._end}.") + ta = self._end + if tb > self._end: + warnings.warn(f"Should have {tb_name}<=t1 but got {tb_name}={tb} and t1={self._end}.") + tb = self._end + if ta > tb: + raise RuntimeError(f"Query times ta={ta:.3f} and tb={tb:.3f} must respect ta <= tb.") + + if ta == tb: + W = torch.zeros(self._size, dtype=self._dtype, device=self._device) + H = None + A = None + if self._have_H: + H = torch.zeros(self._size, dtype=self._dtype, device=self._device) + if self._have_A: + size = (*self._size, *self._size[-1:]) # not self._size[-1] as that may not exist + A = torch.zeros(size, dtype=self._dtype, device=self._device) + else: + if self._dt is None and not self._halfway_tree: + self._num_evaluations += 1 + # We start off with "negative" num evaluations, to give us a small warm-up period at the start. + if self._num_evaluations > 0: + # Compute average step size so far + dt = tb - ta + self._average_dt = (dt + self._average_dt * (self._num_evaluations - 1)) / self._num_evaluations + if self._average_dt < 0.5 * self._tree_dt: + # If 'dt' wasn't specified, then check the average interval length against the size of the + # bottom of the dependency tree. If we're below halfway then refine the tree by splitting all + # the bottom pieces into two. + self._create_dependency_tree(dt) + + # Find the intervals that correspond to the query. We start our search at the last interval we accessed in + # the binary tree, as it's likely that the next query will come nearby. + intervals = self._last_interval._loc(ta, tb) + # Ideally we'd keep track of intervals[0] on the backward pass. Practically speaking len(intervals) tends to + # be 1 or 2 almost always so this isn't a huge deal. + self._last_interval = intervals[-1] + + W, H, A = intervals[0]._increment_and_levy_area() + if len(intervals) > 1: + # If we have multiple intervals then add up their increments and Levy areas. + + for interval in intervals[1:]: + Wi, Hi, Ai = interval._increment_and_levy_area() + if self._have_H: + # Aggregate H: + # Given s < u < t, then + # H_{s,t} = (term1 + term2) / (t - s) + # where + # term1 = (t - u) * (H_{u, t} + W_{s, u} / 2) + # term2 = (u - s) * (H_{s, u} - W_{u, t} / 2) + term1 = (interval._end - interval._start) * (Hi + 0.5 * W) + term2 = (interval._start - ta) * (H - 0.5 * Wi) + H = (term1 + term2) / (interval._end - ta) + if self._have_A and len(self._size) not in (0, 1): + # If len(self._size) in (0, 1) then we treat our scalar / single dimension as a batch + # dimension, so we have zero Levy area. 
(And these unsqueezes will result in a tensor of shape + # (batch, batch) which is wrong.) + + # Let B_{x, y} = \int_x^y W^1_{s,u} dW^2_u. + # Then + # B_{s, t} = \int_s^t W^1_{s,u} dW^2_u + # = \int_s^v W^1_{s,u} dW^2_u + \int_v^t W^1_{s,v} dW^2_u + \int_v^t W^1_{v,u} dW^2_u + # = B_{s, v} + W^1_{s, v} W^2_{v, t} + B_{v, t} + # + # A is now the antisymmetric part of B, which gives the formula below. + A = A + Ai + 0.5 * (W.unsqueeze(-1) * Wi.unsqueeze(-2) - Wi.unsqueeze(-1) * W.unsqueeze(-2)) + W = W + Wi + + U = None + if self._have_H: + U = _H_to_U(W, H, tb - ta) + + if return_U: + if return_A: + return W, U, A + else: + return W, U + else: + if return_A: + return W, A + else: + return W + + def _create_dependency_tree(self, dt): + # For safety we take a min with 100: if people take very large cache sizes then this would then break the + # logarithmic into linear, which causes RecursionErrors. + if self._cache_size is None: # cache_size=None corresponds to infinite cache. + cache_size = 100 + else: + cache_size = min(self._cache_size, 100) + + self._tree_dt = min(self._tree_dt, dt) + # Rationale: We are prepared to hold `cache_size` many things in memory, so when making steps of size `dt` + # then we can afford to have the intervals at the bottom of our binary tree be of size `dt * cache_size`. + # For safety we then make this a bit smaller by multiplying by 0.8. + piece_length = self._tree_dt * cache_size * 0.8 + + def _set_points(interval): + start = interval._start + end = interval._end + if end - start > piece_length: + midway = (end + start) / 2 + interval._loc(start, midway) + _set_points(interval._left_child) + _set_points(interval._right_child) + + _set_points(self) + + def __repr__(self): + if self._dt is None: + dt = None + else: + dt = f"{self._dt:.3f}" + return (f"{self.__class__.__name__}(" + f"t0={self._start:.3f}, " + f"t1={self._end:.3f}, " + f"size={self._size}, " + f"dtype={self._dtype}, " + f"device={repr(self._device)}, " + f"entropy={self._entropy}, " + f"dt={dt}, " + f"tol={self._tol}, " + f"pool_size={self._pool_size}, " + f"cache_size={self._cache_size}, " + f"levy_area_approximation={repr(self._levy_area_approximation)}" + f")") + + def display_binary_tree(self): + stack = [(self, 0)] + out = [] + while len(stack): + elem, depth = stack.pop() + out.append(" " * depth + f"({elem._start}, {elem._end})") + if elem._midway is not None: + stack.append((elem._right_child, depth + 1)) + stack.append((elem._left_child, depth + 1)) + print("\n".join(out)) + + @property + def shape(self): + return self._size + + @property + def dtype(self): + return self._dtype + + @property + def device(self): + return self._device + + @property + def entropy(self): + return self._entropy + + @property + def levy_area_approximation(self): + return self._levy_area_approximation + + @property + def dt(self): + return self._dt + + @property + def tol(self): + return self._tol + + @property + def pool_size(self): + return self._pool_size + + @property + def cache_size(self): + return self._cache_size + + @property + def halfway_tree(self): + return self._halfway_tree + + def size(self): + return self._size + + +class BrownianTree(brownian_base.BaseBrownian): + """Brownian tree with fixed entropy. + + Useful when the map from entropy -> Brownian motion shouldn't depend on the + locations and order of the query points. (As the usual BrownianInterval + does - note that BrownianTree is slower as a result though.) 
+ + To use: + >>> bm = BrownianTree(t0=0.0, w0=torch.zeros(4, 1)) + >>> bm(0., 0.5) + tensor([[ 0.0733], + [-0.5692], + [ 0.1872], + [-0.3889]], device='cuda:0') + """ + + def __init__(self, t0: Scalar, + w0: Tensor, + t1: Optional[Scalar] = None, + w1: Optional[Tensor] = None, + entropy: Optional[int] = None, + tol: float = 1e-6, + pool_size: int = 24, + cache_depth: int = 9, + safety: Optional[float] = None): + """Initialize the Brownian tree. + + The random value generation process exploits the parallel random number paradigm and uses + `numpy.random.SeedSequence`. The default generator is PCG64 (used by `default_rng`). + + Arguments: + t0: Initial time. + w0: Initial state. + t1: Terminal time. + w1: Terminal state. + entropy: Global seed, defaults to `None` for random entropy. + tol: Error tolerance before the binary search is terminated; the search depth ~ log2(tol). + pool_size: Size of the pooled entropy. This parameter affects the query speed significantly. + cache_depth: Unused; deprecated. + safety: Unused; deprecated. + """ + + if t1 is None: + t1 = t0 + 1 + if w1 is None: + W = None + else: + W = w1 - w0 + self._w0 = w0 + self._interval = BrownianInterval(t0=t0, + t1=t1, + size=w0.shape, + dtype=w0.dtype, + device=w0.device, + entropy=entropy, + tol=tol, + pool_size=pool_size, + halfway_tree=True, + W=W) + super(BrownianTree, self).__init__() + + def __call__(self, t, tb=None, return_U=False, return_A=False): + # Deliberately called t rather than ta, for backward compatibility + out = self._interval(t, tb, return_U=return_U, return_A=return_A) + if tb is None and not return_U and not return_A: + out = out + self._w0 + return out + + def __repr__(self): + return f"{self.__class__.__name__}(interval={self._interval})" + + @property + def dtype(self): + return self._interval.dtype + + @property + def device(self): + return self._interval.device + + @property + def shape(self): + return self._interval.shape + + @property + def levy_area_approximation(self): + return self._interval.levy_area_approximation + + +def brownian_interval_like(y: Tensor, + t0: Optional[Scalar] = 0., + t1: Optional[Scalar] = 1., + size: Optional[Tuple[int, ...]] = None, + dtype: Optional[torch.dtype] = None, + device: Optional[Union[str, torch.device]] = None, + **kwargs): + """Returns a BrownianInterval object with the same size, device, and dtype as a given tensor.""" + size = y.shape if size is None else size + dtype = y.dtype if dtype is None else dtype + device = y.device if device is None else device + return brownian_interval.BrownianInterval(t0=t0, t1=t1, size=size, dtype=dtype, device=device, **kwargs) \ No newline at end of file -- Gitee From 57b1782285922406b26c6e71434bfe232bde9bd3 Mon Sep 17 00:00:00 2001 From: zhoufan2956 Date: Wed, 15 Jan 2025 16:36:48 +0800 Subject: [PATCH 2/6] add_stable_audio --- .../stable_audio_open_1.0/stable_audio_open/prompts/prompts.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/MindIE/MindIE-Torch/built-in/foundation/stable_audio_open_1.0/stable_audio_open/prompts/prompts.txt b/MindIE/MindIE-Torch/built-in/foundation/stable_audio_open_1.0/stable_audio_open/prompts/prompts.txt index e977cc3ad7..523a9bcbcf 100644 --- a/MindIE/MindIE-Torch/built-in/foundation/stable_audio_open_1.0/stable_audio_open/prompts/prompts.txt +++ b/MindIE/MindIE-Torch/built-in/foundation/stable_audio_open_1.0/stable_audio_open/prompts/prompts.txt @@ -2,5 +2,5 @@ Berlin techno, rave, drum machine, kick, ARP synthesizer, dark, moody, hypnotic, Uplifting acoustic loop. 
120 BPM. Disco, Driving Drum Machine, Synthesizer, Bass, Piano, Guitars, Instrumental, Clubby, Euphoric, Chicago, New York, 115 BPM. Warm arpeggios on an analog synthesizer with a gradually rising filter cutoff and a reverb tail. -Blackbird song, summer, dusj in the forest. +Blackbird song, summer, dusk in the forest. Rock beat played in a treated studio, session drumming on an acoustic kit. \ No newline at end of file -- Gitee From 523b2e38aadc371d176344aee737402ff03135cb Mon Sep 17 00:00:00 2001 From: zhoufan2956 Date: Thu, 16 Jan 2025 10:11:26 +0800 Subject: [PATCH 3/6] add_stable_audio --- .../inference_stableaudio.py | 28 ++++++++++--------- .../models/stable_audio_transformer.py | 2 +- .../scheduling_cosine_dpmsolver_multistep.py | 6 ++-- .../schedulers/scheduling_dpmsolver_sde.py | 3 +- .../schedulers/scheduling_utils.py | 14 ++++++---- 5 files changed, 31 insertions(+), 22 deletions(-) diff --git a/MindIE/MindIE-Torch/built-in/foundation/stable_audio_open_1.0/stable_audio_open/inference_stableaudio.py b/MindIE/MindIE-Torch/built-in/foundation/stable_audio_open_1.0/stable_audio_open/inference_stableaudio.py index 08801a0082..7204775809 100644 --- a/MindIE/MindIE-Torch/built-in/foundation/stable_audio_open_1.0/stable_audio_open/inference_stableaudio.py +++ b/MindIE/MindIE-Torch/built-in/foundation/stable_audio_open_1.0/stable_audio_open/inference_stableaudio.py @@ -1,9 +1,10 @@ -import torch -import torch_npu import time import json import os import argparse + +import torch +import torch_npu import soundfile as sf from safetensors.torch import load_file @@ -74,6 +75,7 @@ def parse_arguments(): ) return parser.parse_args() + def main(): args = parse_arguments() save_dir = args.save_dir @@ -83,19 +85,18 @@ def main(): torch_npu.npu.set_device(args.device) if args.seed != -1: torch.manual_seed(args.seed) - latents = torch.randn(1, 64, 1024, dtype=torch.float16,device="cpu") - with open(args.model + "/vae/config.json", "r", encoding="utf-8") as reader: + latents = torch.randn(1, 64, 1024, dtype=torch.float16, device="cpu") + with open(os.path.join(args.model, "vae", "config.json"), "r", encoding="utf-8") as reader: data = reader.read() json_data = json.loads(data) init_dict = {key: json_data[key] for key in json_data} vae = AutoencoderOobleck(**init_dict) - vae.load_state_dict(load_file(args.model + "/vae/diffusion_pytorch_model.safetensors"), strict=False) - - tokenizer = T5TokenizerFast.from_pretrained(args.model + "/tokenizer") - text_encoder = T5EncoderModel.from_pretrained(args.model + "/text_encoder") - projection_model = StableAudioProjectionModel.from_pretrained(args.model + "/projection_model") - audio_dit = StableAudioDiTModel.from_pretrained(args.model + "/transformer") - scheduler = CosineDPMSolverMultistepScheduler.from_pretrained(args.model + "/scheduler") + vae.load_state_dict(load_file(os.path.join(args.model, "vae", "diffusion_pytorch_model.safetensors")), strict=False) + tokenizer = T5TokenizerFast.from_pretrained(os.path.join(args.model, "tokenizer")) + text_encoder = T5EncoderModel.from_pretrained(os.path.join(args.model, "text_encoder")) + projection_model = StableAudioProjectionModel.from_pretrained(os.path.join(args.model, "projection_model")) + audio_dit = StableAudioDiTModel.from_pretrained(os.path.join(args.model, "transformer")) + scheduler = CosineDPMSolverMultistepScheduler.from_pretrained(os.path.join(args.model, "scheduler")) npu_stream = torch_npu.npu.Stream() vae = vae.to("npu").to(torch.float16).eval() @@ -129,9 +130,10 @@ def main(): end = time.time() 
if i > skip - 1: total_time += end - begin - prompts_num = i+1 + prompts_num = i + 1 output = audio[0].T.float().cpu().numpy() - sf.write(args.save_dir + "/audio_by_prompt" + str(prompts_num) + ".wav", output, pipe.vae.sampling_rate) + file_path = os.path.join(args.save_dir, f"audio_by_prompt{prompts_num}.wav") + sf.write(file_path, output, pipe.vae.sampling_rate) if prompts_num > skip: average_time = total_time / (prompts_num-skip) else: diff --git a/MindIE/MindIE-Torch/built-in/foundation/stable_audio_open_1.0/stable_audio_open/stableaudio/models/stable_audio_transformer.py b/MindIE/MindIE-Torch/built-in/foundation/stable_audio_open_1.0/stable_audio_open/stableaudio/models/stable_audio_transformer.py index cd8a822c1b..2b02ac0a2e 100644 --- a/MindIE/MindIE-Torch/built-in/foundation/stable_audio_open_1.0/stable_audio_open/stableaudio/models/stable_audio_transformer.py +++ b/MindIE/MindIE-Torch/built-in/foundation/stable_audio_open_1.0/stable_audio_open/stableaudio/models/stable_audio_transformer.py @@ -411,7 +411,7 @@ class StableAudioDiTModel(ModelMixin, ConfigMixin): ) return hidden_states - def _convert_deprecated_attention_blocks(self, state_dict: OrderedDict)->None: + def _convert_deprecated_attention_blocks(self, state_dict: OrderedDict) ->None: for i in range(self.num_layers): self_q_weight = state_dict.pop(f"transformer_blocks.{i}.attn1.to_q.weight", None) self_k_weight = state_dict.pop(f"transformer_blocks.{i}.attn1.to_k.weight", None) diff --git a/MindIE/MindIE-Torch/built-in/foundation/stable_audio_open_1.0/stable_audio_open/stableaudio/schedulers/scheduling_cosine_dpmsolver_multistep.py b/MindIE/MindIE-Torch/built-in/foundation/stable_audio_open_1.0/stable_audio_open/stableaudio/schedulers/scheduling_cosine_dpmsolver_multistep.py index 7a33eae2b9..d4e7805d13 100644 --- a/MindIE/MindIE-Torch/built-in/foundation/stable_audio_open_1.0/stable_audio_open/stableaudio/schedulers/scheduling_cosine_dpmsolver_multistep.py +++ b/MindIE/MindIE-Torch/built-in/foundation/stable_audio_open_1.0/stable_audio_open/stableaudio/schedulers/scheduling_cosine_dpmsolver_multistep.py @@ -352,7 +352,8 @@ class CosineDPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin): lambda_s = torch.log(alpha_s) - torch.log(sigma_s) h = lambda_t - lambda_s - assert noise is not None + if noise is None: + raise ValueError("noise must not be None") x_t = ( (sigma_t / sigma_s * torch.exp(-h)) * sample + (alpha_t * (1 - torch.exp(-2.0 * h))) * model_output @@ -401,7 +402,8 @@ class CosineDPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin): D0, D1 = m0, (1.0 / r0) * (m0 - m1) # sde-dpmsolver++ - assert noise is not None + if noise is None: + raise ValueError("noise must not be None") if self.config.solver_type == "midpoint": x_t = ( (sigma_t / sigma_s0 * torch.exp(-h)) * sample diff --git a/MindIE/MindIE-Torch/built-in/foundation/stable_audio_open_1.0/stable_audio_open/stableaudio/schedulers/scheduling_dpmsolver_sde.py b/MindIE/MindIE-Torch/built-in/foundation/stable_audio_open_1.0/stable_audio_open/stableaudio/schedulers/scheduling_dpmsolver_sde.py index 2193d4175d..8533b082f5 100644 --- a/MindIE/MindIE-Torch/built-in/foundation/stable_audio_open_1.0/stable_audio_open/stableaudio/schedulers/scheduling_dpmsolver_sde.py +++ b/MindIE/MindIE-Torch/built-in/foundation/stable_audio_open_1.0/stable_audio_open/stableaudio/schedulers/scheduling_dpmsolver_sde.py @@ -28,7 +28,8 @@ class BatchedBrownianTree: seed = torch.randint(0, 2**63 - 1, []).item() self.batched = True try: - assert len(seed) == x.shape[0] + if 
len(seed) != x.shape[0]:
+                raise ValueError(f"len(seed) should equal x.shape[0], but got {len(seed)}")
             w0 = w0[0]
         except TypeError:
             seed = [seed]
diff --git a/MindIE/MindIE-Torch/built-in/foundation/stable_audio_open_1.0/stable_audio_open/stableaudio/schedulers/scheduling_utils.py b/MindIE/MindIE-Torch/built-in/foundation/stable_audio_open_1.0/stable_audio_open/stableaudio/schedulers/scheduling_utils.py
index 9bc31717cf..8340167ec1 100644
--- a/MindIE/MindIE-Torch/built-in/foundation/stable_audio_open_1.0/stable_audio_open/stableaudio/schedulers/scheduling_utils.py
+++ b/MindIE/MindIE-Torch/built-in/foundation/stable_audio_open_1.0/stable_audio_open/stableaudio/schedulers/scheduling_utils.py
@@ -14,8 +14,8 @@
 # limitations under the License.
 
 import math
-import trampoline
 import warnings
+import trampoline
 
 import numpy as np
 import torch
@@ -534,11 +534,15 @@ class BrownianInterval(brownian_base.BaseBrownian, _Interval):
                                                       LEVY_AREA_APPROXIMATIONS.foster)
 
         # If we like we can quantise what level we want to compute the Brownian motion to.
-        if tol == 0.:
-            self._round = lambda x: x
+        if math.isclose(tol, 0., rel_tol=1e-5):
+            def round_func(x):
+                return x
         else:
             ndigits = -int(math.log10(tol))
-            self._round = lambda x: round(x, ndigits)
+            def round_func(x):
+                return round(x, ndigits)
+
+        self._round = round_func
 
         # Initalise as _Interval.
         # (Must come after _round but before _w_h)
@@ -734,7 +738,7 @@ class BrownianInterval(brownian_base.BaseBrownian, _Interval):
     def display_binary_tree(self):
         stack = [(self, 0)]
         out = []
-        while len(stack):
+        while stack:
             elem, depth = stack.pop()
             out.append(" " * depth + f"({elem._start}, {elem._end})")
             if elem._midway is not None:
-- 
Gitee


From 5e099a7f57ae64ba8f6bff2c21f0335f79f61740 Mon Sep 17 00:00:00 2001
From: zhoufan2956
Date: Thu, 16 Jan 2025 10:35:14 +0800
Subject: [PATCH 4/6] add_stable_audio

---
 .../stable_audio_open/inference_stableaudio.py | 2 +-
 .../stableaudio/schedulers/scheduling_utils.py | 1 +
 2 files changed, 2 insertions(+), 1 deletion(-)

diff --git a/MindIE/MindIE-Torch/built-in/foundation/stable_audio_open_1.0/stable_audio_open/inference_stableaudio.py b/MindIE/MindIE-Torch/built-in/foundation/stable_audio_open_1.0/stable_audio_open/inference_stableaudio.py
index 7204775809..cca5c130ce 100644
--- a/MindIE/MindIE-Torch/built-in/foundation/stable_audio_open_1.0/stable_audio_open/inference_stableaudio.py
+++ b/MindIE/MindIE-Torch/built-in/foundation/stable_audio_open_1.0/stable_audio_open/inference_stableaudio.py
@@ -135,7 +135,7 @@ def main():
             file_path = os.path.join(args.save_dir, f"audio_by_prompt{prompts_num}.wav")
             sf.write(file_path, output, pipe.vae.sampling_rate)
     if prompts_num > skip:
-        average_time = total_time / (prompts_num-skip)
+        average_time = total_time / (prompts_num - skip)
     else:
         raise ValueError("Infer average time skip first two prompts, ensure that prompts.txt \
             contains more than three prompts")
diff --git a/MindIE/MindIE-Torch/built-in/foundation/stable_audio_open_1.0/stable_audio_open/stableaudio/schedulers/scheduling_utils.py b/MindIE/MindIE-Torch/built-in/foundation/stable_audio_open_1.0/stable_audio_open/stableaudio/schedulers/scheduling_utils.py
index 8340167ec1..e4734b052f 100644
--- a/MindIE/MindIE-Torch/built-in/foundation/stable_audio_open_1.0/stable_audio_open/stableaudio/schedulers/scheduling_utils.py
+++ b/MindIE/MindIE-Torch/built-in/foundation/stable_audio_open_1.0/stable_audio_open/stableaudio/schedulers/scheduling_utils.py
@@ -539,6 +539,7 @@ class BrownianInterval(brownian_base.BaseBrownian, 
_Interval): return x else: ndigits = -int(math.log10(tol)) + def round_func(x): return round(x, ndigits) -- Gitee From 2b4a2281365f6b462abde3087ce2d5f7c4046611 Mon Sep 17 00:00:00 2001 From: zhoufan2956 Date: Sat, 18 Jan 2025 14:08:24 +0800 Subject: [PATCH 5/6] update --- .../stable_audio_open_1.0/stable_audio_open/README.md | 11 ----------- 1 file changed, 11 deletions(-) diff --git a/MindIE/MindIE-Torch/built-in/foundation/stable_audio_open_1.0/stable_audio_open/README.md b/MindIE/MindIE-Torch/built-in/foundation/stable_audio_open_1.0/stable_audio_open/README.md index 389707e71f..195c6a14d7 100644 --- a/MindIE/MindIE-Torch/built-in/foundation/stable_audio_open_1.0/stable_audio_open/README.md +++ b/MindIE/MindIE-Torch/built-in/foundation/stable_audio_open_1.0/stable_audio_open/README.md @@ -1,14 +1,3 @@ ---- -pipeline_tag: text-to-audio -frameworks: - - PyTorch -license: apache-2.0 -library_name: openmind -hardwares: - - NPU -language: - - en ---- ## 一、准备运行环境 **表 1** 版本配套表 -- Gitee From e98671f43331e838e0f82d21f2cd0d06dfe22dd4 Mon Sep 17 00:00:00 2001 From: zhoufan2956 Date: Tue, 21 Jan 2025 17:36:14 +0800 Subject: [PATCH 6/6] update --- .../inference_stableaudio.py | 38 ++++++------------- .../stable_audio_open/stableaudio/__init__.py | 4 +- .../stableaudio/models/__init__.py | 2 +- 3 files changed, 14 insertions(+), 30 deletions(-) diff --git a/MindIE/MindIE-Torch/built-in/foundation/stable_audio_open_1.0/stable_audio_open/inference_stableaudio.py b/MindIE/MindIE-Torch/built-in/foundation/stable_audio_open_1.0/stable_audio_open/inference_stableaudio.py index cca5c130ce..23f3ce14ec 100644 --- a/MindIE/MindIE-Torch/built-in/foundation/stable_audio_open_1.0/stable_audio_open/inference_stableaudio.py +++ b/MindIE/MindIE-Torch/built-in/foundation/stable_audio_open_1.0/stable_audio_open/inference_stableaudio.py @@ -8,12 +8,8 @@ import torch_npu import soundfile as sf from safetensors.torch import load_file -from transformers import T5TokenizerFast, T5EncoderModel from stableaudio import ( StableAudioPipeline, - StableAudioDiTModel, - StableAudioProjectionModel, - CosineDPMSolverMultistepScheduler, AutoencoderOobleck, ) @@ -82,31 +78,19 @@ def main(): if not os.path.exists(save_dir): os.makedirs(save_dir) + npu_stream = torch_npu.npu.Stream() torch_npu.npu.set_device(args.device) if args.seed != -1: torch.manual_seed(args.seed) - latents = torch.randn(1, 64, 1024, dtype=torch.float16, device="cpu") - with open(os.path.join(args.model, "vae", "config.json"), "r", encoding="utf-8") as reader: - data = reader.read() - json_data = json.loads(data) - init_dict = {key: json_data[key] for key in json_data} - vae = AutoencoderOobleck(**init_dict) - vae.load_state_dict(load_file(os.path.join(args.model, "vae", "diffusion_pytorch_model.safetensors")), strict=False) - tokenizer = T5TokenizerFast.from_pretrained(os.path.join(args.model, "tokenizer")) - text_encoder = T5EncoderModel.from_pretrained(os.path.join(args.model, "text_encoder")) - projection_model = StableAudioProjectionModel.from_pretrained(os.path.join(args.model, "projection_model")) - audio_dit = StableAudioDiTModel.from_pretrained(os.path.join(args.model, "transformer")) - scheduler = CosineDPMSolverMultistepScheduler.from_pretrained(os.path.join(args.model, "scheduler")) - - npu_stream = torch_npu.npu.Stream() - vae = vae.to("npu").to(torch.float16).eval() - text_encoder = text_encoder.to("npu").to(torch.float16).eval() - projection_model = projection_model.to("npu").to(torch.float16).eval() - audio_dit = 
audio_dit.to("npu").to(torch.float16).eval() + latents = torch.randn(1, 64, 1024, dtype=torch.float16, device="cpu").to("npu") - pipe = StableAudioPipeline(vae=vae, tokenizer=tokenizer, text_encoder=text_encoder, - projection_model=projection_model, transformer=audio_dit, scheduler=scheduler) - pipe.to("npu") + with open(os.path.join(args.model, "vae", "config.json")) as f: + vae_config = json.load(f) + vae = AutoencoderOobleck.from_config(vae_config) + vae.load_state_dict(load_file(os.path.join(args.model, "vae", "diffusion_pytorch_model.safetensors"))) + + pipe = StableAudioPipeline.from_pretrained(args.model, vae=vae) + pipe.to(torch.float16).to("npu") total_time = 0 prompts_num = 0 @@ -115,14 +99,14 @@ def main(): with os.fdopen(os.open(args.prompt_file, os.O_RDONLY), "r") as f: for i, prompt in enumerate(f): with torch.no_grad(): - npu_stream.synchronize() audio_end_in_s = float(args.audio_end_in_s[i]) if (len(args.audio_end_in_s) > i) else 10.0 + npu_stream.synchronize() begin = time.time() audio = pipe( prompt=prompt, negative_prompt=args.negative_prompt, num_inference_steps=args.num_inference_steps, - latents=latents.to("npu"), + latents=latents, audio_end_in_s=audio_end_in_s, use_cache=args.use_cache, ).audios diff --git a/MindIE/MindIE-Torch/built-in/foundation/stable_audio_open_1.0/stable_audio_open/stableaudio/__init__.py b/MindIE/MindIE-Torch/built-in/foundation/stable_audio_open_1.0/stable_audio_open/stableaudio/__init__.py index 2e896ae145..f484d3fceb 100644 --- a/MindIE/MindIE-Torch/built-in/foundation/stable_audio_open_1.0/stable_audio_open/stableaudio/__init__.py +++ b/MindIE/MindIE-Torch/built-in/foundation/stable_audio_open_1.0/stable_audio_open/stableaudio/__init__.py @@ -1,4 +1,4 @@ from diffusers.models.autoencoders.autoencoder_oobleck import AutoencoderOobleck from .pipeline import StableAudioPipeline, StableAudioProjectionModel -from .models import StableAudioDiTModel -from .schedulers.scheduling_cosine_dpmsolver_multistep import CosineDPMSolverMultistepScheduler \ No newline at end of file +from .models import StableAudioDiTModel, ModelMixin +from .schedulers.scheduling_cosine_dpmsolver_multistep import CosineDPMSolverMultistepScheduler, SchedulerMixin \ No newline at end of file diff --git a/MindIE/MindIE-Torch/built-in/foundation/stable_audio_open_1.0/stable_audio_open/stableaudio/models/__init__.py b/MindIE/MindIE-Torch/built-in/foundation/stable_audio_open_1.0/stable_audio_open/stableaudio/models/__init__.py index 2c132c4286..c4170f5af4 100644 --- a/MindIE/MindIE-Torch/built-in/foundation/stable_audio_open_1.0/stable_audio_open/stableaudio/models/__init__.py +++ b/MindIE/MindIE-Torch/built-in/foundation/stable_audio_open_1.0/stable_audio_open/stableaudio/models/__init__.py @@ -1 +1 @@ -from .stable_audio_transformer import StableAudioDiTModel \ No newline at end of file +from .stable_audio_transformer import StableAudioDiTModel, ModelMixin \ No newline at end of file -- Gitee
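
After [PATCH 6/6], the inference script no longer assembles the tokenizer, text encoder, projection model, transformer and scheduler by hand; only the Oobleck VAE is built explicitly and everything else is resolved by `StableAudioPipeline.from_pretrained`. A minimal sketch of the resulting flow follows. The weight path and the prompt string are illustrative; the calls themselves mirror the patched `inference_stableaudio.py`.

```python
import json
import os

import soundfile as sf
import torch
import torch_npu  # Ascend NPU backend, as used in inference_stableaudio.py
from safetensors.torch import load_file

from stableaudio import StableAudioPipeline, AutoencoderOobleck

model_base = "./stable-audio-open-1.0"  # assumed local weight path, see the README

# Build the Oobleck VAE from its config and load its weights, as the final patch does.
with open(os.path.join(model_base, "vae", "config.json")) as f:
    vae = AutoencoderOobleck.from_config(json.load(f))
vae.load_state_dict(load_file(os.path.join(model_base, "vae", "diffusion_pytorch_model.safetensors")))

# The remaining components are loaded by the pipeline itself from the weight directory.
pipe = StableAudioPipeline.from_pretrained(model_base, vae=vae)
pipe.to(torch.float16).to("npu")

audio = pipe(
    prompt="Warm arpeggios on an analog synthesizer with a gradually rising filter cutoff and a reverb tail.",
    num_inference_steps=100,
    audio_end_in_s=10.0,
).audios

# Write the first generated clip to disk, as the script does for each prompt.
sf.write("audio_by_prompt1.wav", audio[0].T.float().cpu().numpy(), pipe.vae.sampling_rate)
```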
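
The `BrownianInterval` introduced in [PATCH 1/6] (and adjusted in [PATCH 3/6]) is queried over time intervals rather than at single points, which is how the SDE scheduler consumes it. The sketch below exercises it in isolation; the import path follows the package layout added by the patches, and the endpoints, shape and entropy are arbitrary example values.

```python
import torch

# Import path per the package layout added in this patch series (assumed importable).
from stableaudio.schedulers.scheduling_utils import BrownianInterval

# Fixed entropy makes the sample path reproducible across runs.
bm = BrownianInterval(
    t0=0.0,
    t1=1.0,
    size=(4, 3),                            # batch of 4 three-dimensional Brownian motions
    dtype=torch.float32,
    entropy=42,
    levy_area_approximation="space-time",   # enables the space-time Levy area term
)

W = bm(0.25, 0.5)                           # increment W over the interval [0.25, 0.5]
W2, U = bm(0.25, 0.5, return_U=True)        # same increment plus U = h * (W / 2 + H)

assert torch.allclose(W, W2)                # repeated queries over one interval agree
```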