diff --git a/MindIE/MindIE-Torch/built-in/foundation/CogVideoX-5b/README.md b/MindIE/MindIE-Torch/built-in/foundation/CogVideoX-5b/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..35483e07f770ec02679809e455609d175caddbfe
--- /dev/null
+++ b/MindIE/MindIE-Torch/built-in/foundation/CogVideoX-5b/README.md
@@ -0,0 +1,141 @@
+---
+license: apache-2.0
+frameworks:
+ - PyTorch
+language:
+ - en
+hardwares:
+ - NPU
+---
+## 1. Prepare the Runtime Environment
+
+ **Table 1** Version compatibility
+
+ | Component | Version | Setup Guide |
+ | ----- | ----- | ----- |
+ | Python | 3.10.2 | - |
+ | torch | 2.1.0 | - |
+
+### 1.1 Obtain the CANN & MindIE Packages and Prepare the Environment
+- [800I A2/800T A2](https://www.hiascend.com/developer/download/community/result?module=pt+ie+cann&product=4&model=32)
+- [Environment setup guide](https://www.hiascend.com/document/detail/zh/CANNCommunityEdition/80RC2alpha002/softwareinst/instg/instg_0001.html)
+
+### 1.2 Install CANN
+```shell
+# Make the packages executable. {version} is the package version, {arch} is the CPU architecture, and {soc} is the Ascend AI processor model.
+chmod +x ./Ascend-cann-toolkit_{version}_linux-{arch}.run
+chmod +x ./Ascend-cann-kernels-{soc}_{version}_linux.run
+# Verify the consistency and integrity of the installation packages
+./Ascend-cann-toolkit_{version}_linux-{arch}.run --check
+./Ascend-cann-kernels-{soc}_{version}_linux.run --check
+# Install
+./Ascend-cann-toolkit_{version}_linux-{arch}.run --install
+./Ascend-cann-kernels-{soc}_{version}_linux.run --install
+
+# Set environment variables
+source /usr/local/Ascend/ascend-toolkit/set_env.sh
+```
+
+### 1.3 Install MindIE
+```shell
+# Make the package executable. {version} is the package version and {arch} is the CPU architecture.
+chmod +x ./Ascend-mindie_${version}_linux-${arch}.run
+./Ascend-mindie_${version}_linux-${arch}.run --check
+
+# Option 1: install to the default path
+./Ascend-mindie_${version}_linux-${arch}.run --install
+# Set environment variables
+cd /usr/local/Ascend/mindie && source set_env.sh
+
+# Option 2: install to a custom path
+./Ascend-mindie_${version}_linux-${arch}.run --install-path=${AieInstallPath}
+# Set environment variables
+cd ${AieInstallPath}/mindie && source set_env.sh
+```
+
+### 1.4 Install torch_npu
+Install the PyTorch framework, version 2.1.0.
+[Package download](https://download.pytorch.org/whl/cpu/torch/)
+
+Install it with pip:
+```shell
+# {version} is the package version and {arch} is the CPU architecture.
+pip install torch-${version}-cp310-cp310-linux_${arch}.whl
+```
+Then download pytorch_v{pytorchversion}_py{pythonversion}.tar.gz and install torch_npu:
+```shell
+tar -xzvf pytorch_v{pytorchversion}_py{pythonversion}.tar.gz
+# After extraction, the archive contains the torch_npu wheel
+pip install torch_npu-{pytorchversion}.xxxx.{arch}.whl
+```
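+
+A quick sanity check after installation (a minimal sketch; it only assumes the torch and torch_npu wheels above installed successfully):
+```python
+# Verify that torch and the torch_npu plugin import correctly and that an NPU device is visible
+import torch
+import torch_npu
+
+print(torch.__version__)             # expected: 2.1.0
+print(torch_npu.npu.is_available())  # True when the NPU driver and firmware are set up correctly
+```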
+
+### 1.5 Install Required Dependencies
+```shell
+pip3 install -r requirements.txt
+```
+
+## 2. Download This Repository
+
+### 2.1 Clone to a Local Machine
+```shell
+ git clone https://modelers.cn/MindIE/CogVideoX-5b.git
+```
+
+## 3. Using CogVideoX-5b
+
+### 3.1 Weights and Configuration Files
+1. Download the CogVideoX-5b weights (configuration files and weights for the five components: scheduler, text_encoder, tokenizer, transformer, and vae):
+```shell
+ git clone https://modelers.cn/AI-Research/CogVideoX-5B.git
+```
+2. The directory layout of the configuration and weight files is shown below.
+```commandline
+|----CogVideoX-5b
+| |---- model_index.json
+| |---- scheduler
+| | |---- scheduler_config.json
+| |---- text_encoder
+| | |---- config.json
+| | |---- model weights
+| |---- tokenizer
+| | |---- config.json
+| | |---- model weights
+| |---- transformer
+| | |---- config.json
+| | |---- model weights
+| |---- vae
+| | |---- config.json
+| | |---- model weights
+```
+
+### 3.2 Single-Card Single-Prompt Functional Test
+Set the weight path:
+```shell
+model_path='data/CogVideoX-5b'
+```
+
+Run the command:
+```shell
+export CPU_AFFINITY_CONF=1
+export HCCL_OP_EXPANSION_MODE="AIV"
+TASK_QUEUE_ENABLE=2 ASCEND_RT_VISIBLE_DEVICES=0 torchrun --master_port=2002 --nproc_per_node=1 inference.py \
+ --prompt "A dog" \
+ --model_path ${model_path} \
+ --num_frames 48 \
+ --width 720 \
+ --height 480 \
+ --fps 8 \
+ --num_inference_steps 50
+```
+Parameter description:
+- CPU_AFFINITY_CONF=1: environment variable that enables CPU core binding.
+- HCCL_OP_EXPANSION_MODE="AIV": environment variable that controls the expansion (orchestration) mode of HCCL communication operators.
+- TASK_QUEUE_ENABLE=2: enables the level-2 task-queue pipeline.
+- ASCEND_RT_VISIBLE_DEVICES=0: device ID; a different card can be specified.
+- prompt: text prompt describing the video to generate.
+- model_path: weight path containing the configuration files and weights of the five components (scheduler, text_encoder, tokenizer, transformer, vae).
+- num_frames: number of frames in the generated video.
+- width: output video width in pixels.
+- height: output video height in pixels.
+- fps: frame rate of the generated video; defaults to 8.
+- num_inference_steps: number of inference (denoising) steps; defaults to 50.
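+
+For reference, the following is a minimal Python sketch of the same kind of run. It assumes the local `CogVideoXPipeline` mirrors the diffusers CogVideoX text-to-video API and that diffusers' `export_to_video` helper is available; the output filename is illustrative, and `inference.py` remains the supported entry point.
+```python
+import torch
+from diffusers.utils import export_to_video
+
+from cogvideox_5b import CogVideoXPipeline
+
+# Load the five components (scheduler, text_encoder, tokenizer, transformer, vae) from model_path
+pipe = CogVideoXPipeline.from_pretrained("data/CogVideoX-5b", torch_dtype=torch.bfloat16).to("npu")
+
+video = pipe(
+    prompt="A dog",
+    num_frames=48,
+    height=480,
+    width=720,
+    num_inference_steps=50,
+).frames[0]
+
+export_to_video(video, "output.mp4", fps=8)
+```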
diff --git a/MindIE/MindIE-Torch/built-in/foundation/CogVideoX-5b/cogvideox_5b/__init__.py b/MindIE/MindIE-Torch/built-in/foundation/CogVideoX-5b/cogvideox_5b/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..896b34e71a1edabde2afe4312052c37176589d99
--- /dev/null
+++ b/MindIE/MindIE-Torch/built-in/foundation/CogVideoX-5b/cogvideox_5b/__init__.py
@@ -0,0 +1,3 @@
+from .pipelines import CogVideoXPipeline
+from .models import CogVideoXTransformer3DModel
+from .utils import get_world_size, get_rank, all_gather
\ No newline at end of file
diff --git a/MindIE/MindIE-Torch/built-in/foundation/CogVideoX-5b/cogvideox_5b/models/__init__.py b/MindIE/MindIE-Torch/built-in/foundation/CogVideoX-5b/cogvideox_5b/models/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..a267e101cd0c03bcc4f076ed254a02309fb22712
--- /dev/null
+++ b/MindIE/MindIE-Torch/built-in/foundation/CogVideoX-5b/cogvideox_5b/models/__init__.py
@@ -0,0 +1 @@
+from .transformers import CogVideoXTransformer3DModel
diff --git a/MindIE/MindIE-Torch/built-in/foundation/CogVideoX-5b/cogvideox_5b/models/activations.py b/MindIE/MindIE-Torch/built-in/foundation/CogVideoX-5b/cogvideox_5b/models/activations.py
new file mode 100644
index 0000000000000000000000000000000000000000..7cd6938b225b9499d6d09ccf6a9faebfee2e8d49
--- /dev/null
+++ b/MindIE/MindIE-Torch/built-in/foundation/CogVideoX-5b/cogvideox_5b/models/activations.py
@@ -0,0 +1,165 @@
+# coding=utf-8
+# Copyright 2024 HuggingFace Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import torch
+import torch.nn.functional as F
+from torch import nn
+
+from diffusers.utils import deprecate
+from diffusers.utils.import_utils import is_torch_npu_available
+
+
+if is_torch_npu_available():
+ import torch_npu
+
+ACTIVATION_FUNCTIONS = {
+ "swish": nn.SiLU(),
+ "silu": nn.SiLU(),
+ "mish": nn.Mish(),
+ "gelu": nn.GELU(),
+ "relu": nn.ReLU(),
+}
+
+
+def get_activation(act_fn: str) -> nn.Module:
+ """Helper function to get activation function from string.
+
+ Args:
+ act_fn (str): Name of activation function.
+
+ Returns:
+ nn.Module: Activation function.
+ """
+
+ act_fn = act_fn.lower()
+ if act_fn in ACTIVATION_FUNCTIONS:
+ return ACTIVATION_FUNCTIONS[act_fn]
+ else:
+ raise ValueError(f"Unsupported activation function: {act_fn}")
+
+
+class FP32SiLU(nn.Module):
+ r"""
+ SiLU activation function with input upcasted to torch.float32.
+ """
+
+ def __init__(self):
+ super().__init__()
+
+ def forward(self, inputs: torch.Tensor) -> torch.Tensor:
+ return F.silu(inputs.float(), inplace=False).to(inputs.dtype)
+
+
+class GELU(nn.Module):
+ r"""
+ GELU activation function with tanh approximation support with `approximate="tanh"`.
+
+ Parameters:
+ dim_in (`int`): The number of channels in the input.
+ dim_out (`int`): The number of channels in the output.
+ approximate (`str`, *optional*, defaults to `"none"`): If `"tanh"`, use tanh approximation.
+ bias (`bool`, defaults to True): Whether to use a bias in the linear layer.
+ """
+
+ def __init__(self, dim_in: int, dim_out: int, approximate: str = "none", bias: bool = True):
+ super().__init__()
+ self.proj = nn.Linear(dim_in, dim_out, bias=bias)
+ self.approximate = approximate
+
+ def gelu(self, gate: torch.Tensor) -> torch.Tensor:
+ if gate.device.type != "mps":
+ return F.gelu(gate, approximate=self.approximate)
+ # mps: gelu is not implemented for float16
+ return F.gelu(gate.to(dtype=torch.float32), approximate=self.approximate).to(dtype=gate.dtype)
+
+ def forward(self, hidden_states):
+ hidden_states = self.proj(hidden_states)
+ hidden_states = self.gelu(hidden_states)
+ return hidden_states
+
+
+class GEGLU(nn.Module):
+ r"""
+ A [variant](https://arxiv.org/abs/2002.05202) of the gated linear unit activation function.
+
+ Parameters:
+ dim_in (`int`): The number of channels in the input.
+ dim_out (`int`): The number of channels in the output.
+ bias (`bool`, defaults to True): Whether to use a bias in the linear layer.
+ """
+
+ def __init__(self, dim_in: int, dim_out: int, bias: bool = True):
+ super().__init__()
+ self.proj = nn.Linear(dim_in, dim_out * 2, bias=bias)
+
+ def gelu(self, gate: torch.Tensor) -> torch.Tensor:
+ if gate.device.type != "mps":
+ return F.gelu(gate)
+ # mps: gelu is not implemented for float16
+ return F.gelu(gate.to(dtype=torch.float32)).to(dtype=gate.dtype)
+
+ def forward(self, hidden_states, *args, **kwargs):
+ if len(args) > 0 or kwargs.get("scale", None) is not None:
+ deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
+ deprecate("scale", "1.0.0", deprecation_message)
+ hidden_states = self.proj(hidden_states)
+ if is_torch_npu_available():
+ # using torch_npu.npu_geglu can run faster and save memory on NPU.
+ return torch_npu.npu_geglu(hidden_states, dim=-1, approximate=1)[0]
+ else:
+ hidden_states, gate = hidden_states.chunk(2, dim=-1)
+ return hidden_states * self.gelu(gate)
+
+
+class SwiGLU(nn.Module):
+ r"""
+ A [variant](https://arxiv.org/abs/2002.05202) of the gated linear unit activation function. It's similar to `GEGLU`
+ but uses SiLU / Swish instead of GeLU.
+
+ Parameters:
+ dim_in (`int`): The number of channels in the input.
+ dim_out (`int`): The number of channels in the output.
+ bias (`bool`, defaults to True): Whether to use a bias in the linear layer.
+ """
+
+ def __init__(self, dim_in: int, dim_out: int, bias: bool = True):
+ super().__init__()
+ self.proj = nn.Linear(dim_in, dim_out * 2, bias=bias)
+ self.activation = nn.SiLU()
+
+ def forward(self, hidden_states):
+ hidden_states = self.proj(hidden_states)
+ hidden_states, gate = hidden_states.chunk(2, dim=-1)
+ return hidden_states * self.activation(gate)
+
+
+class ApproximateGELU(nn.Module):
+ r"""
+ The approximate form of the Gaussian Error Linear Unit (GELU). For more details, see section 2 of this
+ [paper](https://arxiv.org/abs/1606.08415).
+
+ Parameters:
+ dim_in (`int`): The number of channels in the input.
+ dim_out (`int`): The number of channels in the output.
+ bias (`bool`, defaults to True): Whether to use a bias in the linear layer.
+ """
+
+ def __init__(self, dim_in: int, dim_out: int, bias: bool = True):
+ super().__init__()
+ self.proj = nn.Linear(dim_in, dim_out, bias=bias)
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ x = self.proj(x)
+ return x * torch.sigmoid(1.702 * x)
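+
+
+# Usage sketch (illustrative comment only, not executed on import; the shapes below are assumptions):
+#   act = get_activation("silu")             # returns nn.SiLU() from ACTIVATION_FUNCTIONS
+#   gate = GEGLU(dim_in=320, dim_out=1280)   # Linear(320 -> 2560), then gated down to 1280
+#   y = gate(torch.randn(2, 77, 320))        # y.shape == (2, 77, 1280)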
diff --git a/MindIE/MindIE-Torch/built-in/foundation/CogVideoX-5b/cogvideox_5b/models/attention.py b/MindIE/MindIE-Torch/built-in/foundation/CogVideoX-5b/cogvideox_5b/models/attention.py
new file mode 100644
index 0000000000000000000000000000000000000000..c75226dd30e0e292b372d2445b3df00479632fd8
--- /dev/null
+++ b/MindIE/MindIE-Torch/built-in/foundation/CogVideoX-5b/cogvideox_5b/models/attention.py
@@ -0,0 +1,1230 @@
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import Any, Dict, List, Optional, Tuple
+
+import torch
+import torch.nn.functional as F
+from torch import nn
+
+from diffusers.utils import deprecate, logging
+from diffusers.utils.torch_utils import maybe_allow_in_graph
+from .activations import GEGLU, GELU, ApproximateGELU, FP32SiLU, SwiGLU
+from .attention_processor import Attention, JointAttnProcessor2_0
+from .embeddings import SinusoidalPositionalEmbedding
+from .normalization import AdaLayerNorm, AdaLayerNormContinuous, AdaLayerNormZero, RMSNorm, SD35AdaLayerNormZeroX
+
+logger = logging.get_logger(__name__)
+
+def _chunked_feed_forward(ff: nn.Module, hidden_states: torch.Tensor, chunk_dim: int, chunk_size: int):
+ # "feed_forward_chunk_size" can be used to save memory
+ if hidden_states.shape[chunk_dim] % chunk_size != 0:
+ raise ValueError(
+ f"`hidden_states` dimension to be chunked: {hidden_states.shape[chunk_dim]} has to be divisible by chunk size: {chunk_size}. Make sure to set an appropriate `chunk_size` when calling `unet.enable_forward_chunking`."
+ )
+
+ num_chunks = hidden_states.shape[chunk_dim] // chunk_size
+ ff_output = torch.cat(
+ [ff(hid_slice) for hid_slice in hidden_states.chunk(num_chunks, dim=chunk_dim)],
+ dim=chunk_dim,
+ )
+ return ff_output
+
+
+@maybe_allow_in_graph
+class GatedSelfAttentionDense(nn.Module):
+ r"""
+ A gated self-attention dense layer that combines visual features and object features.
+
+ Parameters:
+ query_dim (`int`): The number of channels in the query.
+ context_dim (`int`): The number of channels in the context.
+ n_heads (`int`): The number of heads to use for attention.
+ d_head (`int`): The number of channels in each head.
+ """
+
+ def __init__(self, query_dim: int, context_dim: int, n_heads: int, d_head: int):
+ super().__init__()
+
+ # we need a linear projection since we need cat visual feature and obj feature
+ self.linear = nn.Linear(context_dim, query_dim)
+
+ self.attn = Attention(query_dim=query_dim, heads=n_heads, dim_head=d_head)
+ self.ff = FeedForward(query_dim, activation_fn="geglu")
+
+ self.norm1 = nn.LayerNorm(query_dim)
+ self.norm2 = nn.LayerNorm(query_dim)
+
+ self.register_parameter("alpha_attn", nn.Parameter(torch.tensor(0.0)))
+ self.register_parameter("alpha_dense", nn.Parameter(torch.tensor(0.0)))
+
+ self.enabled = True
+
+ def forward(self, x: torch.Tensor, objs: torch.Tensor) -> torch.Tensor:
+ if not self.enabled:
+ return x
+
+ n_visual = x.shape[1]
+ objs = self.linear(objs)
+
+ x = x + self.alpha_attn.tanh() * self.attn(self.norm1(torch.cat([x, objs], dim=1)))[:, :n_visual, :]
+ x = x + self.alpha_dense.tanh() * self.ff(self.norm2(x))
+
+ return x
+
+
+@maybe_allow_in_graph
+class JointTransformerBlock(nn.Module):
+ r"""
+ A Transformer block following the MMDiT architecture, introduced in Stable Diffusion 3.
+
+ Reference: https://arxiv.org/abs/2403.03206
+
+ Parameters:
+ dim (`int`): The number of channels in the input and output.
+ num_attention_heads (`int`): The number of heads to use for multi-head attention.
+ attention_head_dim (`int`): The number of channels in each head.
+ context_pre_only (`bool`): Boolean to determine if we should add some blocks associated with the
+ processing of `context` conditions.
+ """
+
+ def __init__(
+ self,
+ dim: int,
+ num_attention_heads: int,
+ attention_head_dim: int,
+ context_pre_only: bool = False,
+ qk_norm: Optional[str] = None,
+ use_dual_attention: bool = False,
+ ):
+ super().__init__()
+
+ self.use_dual_attention = use_dual_attention
+ self.context_pre_only = context_pre_only
+ context_norm_type = "ada_norm_continous" if context_pre_only else "ada_norm_zero"
+
+ if use_dual_attention:
+ self.norm1 = SD35AdaLayerNormZeroX(dim)
+ else:
+ self.norm1 = AdaLayerNormZero(dim)
+
+ if context_norm_type == "ada_norm_continous":
+ self.norm1_context = AdaLayerNormContinuous(
+ dim, dim, elementwise_affine=False, eps=1e-6, bias=True, norm_type="layer_norm"
+ )
+ elif context_norm_type == "ada_norm_zero":
+ self.norm1_context = AdaLayerNormZero(dim)
+ else:
+ raise ValueError(
+ f"Unknown context_norm_type: {context_norm_type}, currently only support `ada_norm_continous`, `ada_norm_zero`"
+ )
+
+ if hasattr(F, "scaled_dot_product_attention"):
+ processor = JointAttnProcessor2_0()
+ else:
+ raise ValueError(
+ "The current PyTorch version does not support the `scaled_dot_product_attention` function."
+ )
+
+ self.attn = Attention(
+ query_dim=dim,
+ cross_attention_dim=None,
+ added_kv_proj_dim=dim,
+ dim_head=attention_head_dim,
+ heads=num_attention_heads,
+ out_dim=dim,
+ context_pre_only=context_pre_only,
+ bias=True,
+ processor=processor,
+ qk_norm=qk_norm,
+ eps=1e-6,
+ )
+
+ if use_dual_attention:
+ self.attn2 = Attention(
+ query_dim=dim,
+ cross_attention_dim=None,
+ dim_head=attention_head_dim,
+ heads=num_attention_heads,
+ out_dim=dim,
+ bias=True,
+ processor=processor,
+ qk_norm=qk_norm,
+ eps=1e-6,
+ )
+ else:
+ self.attn2 = None
+
+ self.norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
+ self.ff = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate")
+
+ if not context_pre_only:
+ self.norm2_context = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
+ self.ff_context = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate")
+ else:
+ self.norm2_context = None
+ self.ff_context = None
+
+ # let chunk size default to None
+ self._chunk_size = None
+ self._chunk_dim = 0
+
+ # Copied from diffusers.models.attention.BasicTransformerBlock.set_chunk_feed_forward
+ def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int = 0):
+ # Sets chunk feed-forward
+ self._chunk_size = chunk_size
+ self._chunk_dim = dim
+
+ def forward(
+ self, hidden_states: torch.FloatTensor, encoder_hidden_states: torch.FloatTensor, temb: torch.FloatTensor
+ ):
+ if self.use_dual_attention:
+ norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp, norm_hidden_states2, gate_msa2 = self.norm1(
+ hidden_states, emb=temb
+ )
+ else:
+ norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb)
+
+ if self.context_pre_only:
+ norm_encoder_hidden_states = self.norm1_context(encoder_hidden_states, temb)
+ else:
+ norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.norm1_context(
+ encoder_hidden_states, emb=temb
+ )
+
+ # Attention.
+ attn_output, context_attn_output = self.attn(
+ hidden_states=norm_hidden_states, encoder_hidden_states=norm_encoder_hidden_states
+ )
+
+ # Process attention outputs for the `hidden_states`.
+ attn_output = gate_msa.unsqueeze(1) * attn_output
+ hidden_states = hidden_states + attn_output
+
+ if self.use_dual_attention:
+ attn_output2 = self.attn2(hidden_states=norm_hidden_states2)
+ attn_output2 = gate_msa2.unsqueeze(1) * attn_output2
+ hidden_states = hidden_states + attn_output2
+
+ norm_hidden_states = self.norm2(hidden_states)
+ norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
+ if self._chunk_size is not None:
+ # "feed_forward_chunk_size" can be used to save memory
+ ff_output = _chunked_feed_forward(self.ff, norm_hidden_states, self._chunk_dim, self._chunk_size)
+ else:
+ ff_output = self.ff(norm_hidden_states)
+ ff_output = gate_mlp.unsqueeze(1) * ff_output
+
+ hidden_states = hidden_states + ff_output
+
+ # Process attention outputs for the `encoder_hidden_states`.
+ if self.context_pre_only:
+ encoder_hidden_states = None
+ else:
+ context_attn_output = c_gate_msa.unsqueeze(1) * context_attn_output
+ encoder_hidden_states = encoder_hidden_states + context_attn_output
+
+ norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states)
+ norm_encoder_hidden_states = norm_encoder_hidden_states * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None]
+ if self._chunk_size is not None:
+ # "feed_forward_chunk_size" can be used to save memory
+ context_ff_output = _chunked_feed_forward(
+ self.ff_context, norm_encoder_hidden_states, self._chunk_dim, self._chunk_size
+ )
+ else:
+ context_ff_output = self.ff_context(norm_encoder_hidden_states)
+ encoder_hidden_states = encoder_hidden_states + c_gate_mlp.unsqueeze(1) * context_ff_output
+
+ return encoder_hidden_states, hidden_states
+
+
+@maybe_allow_in_graph
+class BasicTransformerBlock(nn.Module):
+ r"""
+ A basic Transformer block.
+
+ Parameters:
+ dim (`int`): The number of channels in the input and output.
+ num_attention_heads (`int`): The number of heads to use for multi-head attention.
+ attention_head_dim (`int`): The number of channels in each head.
+ dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
+ cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention.
+ activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
+        num_embeds_ada_norm (`int`, *optional*):
+            The number of diffusion steps used during training. See `Transformer2DModel`.
+        attention_bias (`bool`, *optional*, defaults to `False`):
+            Configure if the attentions should contain a bias parameter.
+ only_cross_attention (`bool`, *optional*):
+ Whether to use only cross-attention layers. In this case two cross attention layers are used.
+ double_self_attention (`bool`, *optional*):
+ Whether to use two self-attention layers. In this case no cross attention layers are used.
+ upcast_attention (`bool`, *optional*):
+ Whether to upcast the attention computation to float32. This is useful for mixed precision training.
+ norm_elementwise_affine (`bool`, *optional*, defaults to `True`):
+ Whether to use learnable elementwise affine parameters for normalization.
+ norm_type (`str`, *optional*, defaults to `"layer_norm"`):
+ The normalization layer to use. Can be `"layer_norm"`, `"ada_norm"` or `"ada_norm_zero"`.
+ final_dropout (`bool` *optional*, defaults to False):
+ Whether to apply a final dropout after the last feed-forward layer.
+ attention_type (`str`, *optional*, defaults to `"default"`):
+ The type of attention to use. Can be `"default"` or `"gated"` or `"gated-text-image"`.
+ positional_embeddings (`str`, *optional*, defaults to `None`):
+ The type of positional embeddings to apply to.
+ num_positional_embeddings (`int`, *optional*, defaults to `None`):
+ The maximum number of positional embeddings to apply.
+ """
+
+ def __init__(
+ self,
+ dim: int,
+ num_attention_heads: int,
+ attention_head_dim: int,
+ dropout=0.0,
+ cross_attention_dim: Optional[int] = None,
+ activation_fn: str = "geglu",
+ num_embeds_ada_norm: Optional[int] = None,
+ attention_bias: bool = False,
+ only_cross_attention: bool = False,
+ double_self_attention: bool = False,
+ upcast_attention: bool = False,
+ norm_elementwise_affine: bool = True,
+ norm_type: str = "layer_norm", # 'layer_norm', 'ada_norm', 'ada_norm_zero', 'ada_norm_single', 'ada_norm_continuous', 'layer_norm_i2vgen'
+ norm_eps: float = 1e-5,
+ final_dropout: bool = False,
+ attention_type: str = "default",
+ positional_embeddings: Optional[str] = None,
+ num_positional_embeddings: Optional[int] = None,
+ ada_norm_continous_conditioning_embedding_dim: Optional[int] = None,
+ ada_norm_bias: Optional[int] = None,
+ ff_inner_dim: Optional[int] = None,
+ ff_bias: bool = True,
+ attention_out_bias: bool = True,
+ ):
+ super().__init__()
+ self.dim = dim
+ self.num_attention_heads = num_attention_heads
+ self.attention_head_dim = attention_head_dim
+ self.dropout = dropout
+ self.cross_attention_dim = cross_attention_dim
+ self.activation_fn = activation_fn
+ self.attention_bias = attention_bias
+ self.double_self_attention = double_self_attention
+ self.norm_elementwise_affine = norm_elementwise_affine
+ self.positional_embeddings = positional_embeddings
+ self.num_positional_embeddings = num_positional_embeddings
+ self.only_cross_attention = only_cross_attention
+
+ # We keep these boolean flags for backward-compatibility.
+ self.use_ada_layer_norm_zero = (num_embeds_ada_norm is not None) and norm_type == "ada_norm_zero"
+ self.use_ada_layer_norm = (num_embeds_ada_norm is not None) and norm_type == "ada_norm"
+ self.use_ada_layer_norm_single = norm_type == "ada_norm_single"
+ self.use_layer_norm = norm_type == "layer_norm"
+ self.use_ada_layer_norm_continuous = norm_type == "ada_norm_continuous"
+
+ if norm_type in ("ada_norm", "ada_norm_zero") and num_embeds_ada_norm is None:
+ raise ValueError(
+ f"`norm_type` is set to {norm_type}, but `num_embeds_ada_norm` is not defined. Please make sure to"
+ f" define `num_embeds_ada_norm` if setting `norm_type` to {norm_type}."
+ )
+
+ self.norm_type = norm_type
+ self.num_embeds_ada_norm = num_embeds_ada_norm
+
+ if positional_embeddings and (num_positional_embeddings is None):
+ raise ValueError(
+ "If `positional_embedding` type is defined, `num_positition_embeddings` must also be defined."
+ )
+
+ if positional_embeddings == "sinusoidal":
+ self.pos_embed = SinusoidalPositionalEmbedding(dim, max_seq_length=num_positional_embeddings)
+ else:
+ self.pos_embed = None
+
+ # Define 3 blocks. Each block has its own normalization layer.
+ # 1. Self-Attn
+ if norm_type == "ada_norm":
+ self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm)
+ elif norm_type == "ada_norm_zero":
+ self.norm1 = AdaLayerNormZero(dim, num_embeds_ada_norm)
+ elif norm_type == "ada_norm_continuous":
+ self.norm1 = AdaLayerNormContinuous(
+ dim,
+ ada_norm_continous_conditioning_embedding_dim,
+ norm_elementwise_affine,
+ norm_eps,
+ ada_norm_bias,
+ "rms_norm",
+ )
+ else:
+ self.norm1 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)
+
+ self.attn1 = Attention(
+ query_dim=dim,
+ heads=num_attention_heads,
+ dim_head=attention_head_dim,
+ dropout=dropout,
+ bias=attention_bias,
+ cross_attention_dim=cross_attention_dim if only_cross_attention else None,
+ upcast_attention=upcast_attention,
+ out_bias=attention_out_bias,
+ )
+
+ # 2. Cross-Attn
+ if cross_attention_dim is not None or double_self_attention:
+ # We currently only use AdaLayerNormZero for self attention where there will only be one attention block.
+ # I.e. the number of returned modulation chunks from AdaLayerZero would not make sense if returned during
+ # the second cross attention block.
+ if norm_type == "ada_norm":
+ self.norm2 = AdaLayerNorm(dim, num_embeds_ada_norm)
+ elif norm_type == "ada_norm_continuous":
+ self.norm2 = AdaLayerNormContinuous(
+ dim,
+ ada_norm_continous_conditioning_embedding_dim,
+ norm_elementwise_affine,
+ norm_eps,
+ ada_norm_bias,
+ "rms_norm",
+ )
+ else:
+ self.norm2 = nn.LayerNorm(dim, norm_eps, norm_elementwise_affine)
+
+ self.attn2 = Attention(
+ query_dim=dim,
+ cross_attention_dim=cross_attention_dim if not double_self_attention else None,
+ heads=num_attention_heads,
+ dim_head=attention_head_dim,
+ dropout=dropout,
+ bias=attention_bias,
+ upcast_attention=upcast_attention,
+ out_bias=attention_out_bias,
+ ) # is self-attn if encoder_hidden_states is none
+ else:
+ if norm_type == "ada_norm_single": # For Latte
+ self.norm2 = nn.LayerNorm(dim, norm_eps, norm_elementwise_affine)
+ else:
+ self.norm2 = None
+ self.attn2 = None
+
+ # 3. Feed-forward
+ if norm_type == "ada_norm_continuous":
+ self.norm3 = AdaLayerNormContinuous(
+ dim,
+ ada_norm_continous_conditioning_embedding_dim,
+ norm_elementwise_affine,
+ norm_eps,
+ ada_norm_bias,
+ "layer_norm",
+ )
+
+ elif norm_type in ["ada_norm_zero", "ada_norm", "layer_norm"]:
+ self.norm3 = nn.LayerNorm(dim, norm_eps, norm_elementwise_affine)
+ elif norm_type == "layer_norm_i2vgen":
+ self.norm3 = None
+
+ self.ff = FeedForward(
+ dim,
+ dropout=dropout,
+ activation_fn=activation_fn,
+ final_dropout=final_dropout,
+ inner_dim=ff_inner_dim,
+ bias=ff_bias,
+ )
+
+ # 4. Fuser
+ if attention_type == "gated" or attention_type == "gated-text-image":
+ self.fuser = GatedSelfAttentionDense(dim, cross_attention_dim, num_attention_heads, attention_head_dim)
+
+ # 5. Scale-shift for PixArt-Alpha.
+ if norm_type == "ada_norm_single":
+ self.scale_shift_table = nn.Parameter(torch.randn(6, dim) / dim**0.5)
+
+ # let chunk size default to None
+ self._chunk_size = None
+ self._chunk_dim = 0
+
+ def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int = 0):
+ # Sets chunk feed-forward
+ self._chunk_size = chunk_size
+ self._chunk_dim = dim
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ encoder_hidden_states: Optional[torch.Tensor] = None,
+ encoder_attention_mask: Optional[torch.Tensor] = None,
+ timestep: Optional[torch.LongTensor] = None,
+ cross_attention_kwargs: Dict[str, Any] = None,
+ class_labels: Optional[torch.LongTensor] = None,
+ added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
+ ) -> torch.Tensor:
+ if cross_attention_kwargs is not None:
+ if cross_attention_kwargs.get("scale", None) is not None:
+ logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.")
+
+ # Notice that normalization is always applied before the real computation in the following blocks.
+ # 0. Self-Attention
+ batch_size = hidden_states.shape[0]
+
+ if self.norm_type == "ada_norm":
+ norm_hidden_states = self.norm1(hidden_states, timestep)
+ elif self.norm_type == "ada_norm_zero":
+ norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(
+ hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype
+ )
+ elif self.norm_type in ["layer_norm", "layer_norm_i2vgen"]:
+ norm_hidden_states = self.norm1(hidden_states)
+ elif self.norm_type == "ada_norm_continuous":
+ norm_hidden_states = self.norm1(hidden_states, added_cond_kwargs["pooled_text_emb"])
+ elif self.norm_type == "ada_norm_single":
+ shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
+ self.scale_shift_table[None] + timestep.reshape(batch_size, 6, -1)
+ ).chunk(6, dim=1)
+ norm_hidden_states = self.norm1(hidden_states)
+ norm_hidden_states = norm_hidden_states * (1 + scale_msa) + shift_msa
+ else:
+ raise ValueError("Incorrect norm used")
+
+ if self.pos_embed is not None:
+ norm_hidden_states = self.pos_embed(norm_hidden_states)
+
+ # 1. Prepare GLIGEN inputs
+ cross_attention_kwargs = cross_attention_kwargs.copy() if cross_attention_kwargs is not None else {}
+ gligen_kwargs = cross_attention_kwargs.pop("gligen", None)
+
+ attn_output = self.attn1(
+ norm_hidden_states,
+ encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None,
+ attention_mask=attention_mask,
+ **cross_attention_kwargs,
+ )
+
+ if self.norm_type == "ada_norm_zero":
+ attn_output = gate_msa.unsqueeze(1) * attn_output
+ elif self.norm_type == "ada_norm_single":
+ attn_output = gate_msa * attn_output
+
+ hidden_states = attn_output + hidden_states
+ if hidden_states.ndim == 4:
+ hidden_states = hidden_states.squeeze(1)
+
+ # 1.2 GLIGEN Control
+ if gligen_kwargs is not None:
+ hidden_states = self.fuser(hidden_states, gligen_kwargs["objs"])
+
+ # 3. Cross-Attention
+ if self.attn2 is not None:
+ if self.norm_type == "ada_norm":
+ norm_hidden_states = self.norm2(hidden_states, timestep)
+ elif self.norm_type in ["ada_norm_zero", "layer_norm", "layer_norm_i2vgen"]:
+ norm_hidden_states = self.norm2(hidden_states)
+ elif self.norm_type == "ada_norm_single":
+ # For PixArt norm2 isn't applied here:
+ # https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L70C1-L76C103
+ norm_hidden_states = hidden_states
+ elif self.norm_type == "ada_norm_continuous":
+ norm_hidden_states = self.norm2(hidden_states, added_cond_kwargs["pooled_text_emb"])
+ else:
+ raise ValueError("Incorrect norm")
+
+ if self.pos_embed is not None and self.norm_type != "ada_norm_single":
+ norm_hidden_states = self.pos_embed(norm_hidden_states)
+
+ attn_output = self.attn2(
+ norm_hidden_states,
+ encoder_hidden_states=encoder_hidden_states,
+ attention_mask=encoder_attention_mask,
+ **cross_attention_kwargs,
+ )
+ hidden_states = attn_output + hidden_states
+
+ # 4. Feed-forward
+        # i2vgen doesn't have this norm
+ if self.norm_type == "ada_norm_continuous":
+ norm_hidden_states = self.norm3(hidden_states, added_cond_kwargs["pooled_text_emb"])
+ elif not self.norm_type == "ada_norm_single":
+ norm_hidden_states = self.norm3(hidden_states)
+
+ if self.norm_type == "ada_norm_zero":
+ norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
+
+ if self.norm_type == "ada_norm_single":
+ norm_hidden_states = self.norm2(hidden_states)
+ norm_hidden_states = norm_hidden_states * (1 + scale_mlp) + shift_mlp
+
+ if self._chunk_size is not None:
+ # "feed_forward_chunk_size" can be used to save memory
+ ff_output = _chunked_feed_forward(self.ff, norm_hidden_states, self._chunk_dim, self._chunk_size)
+ else:
+ ff_output = self.ff(norm_hidden_states)
+
+ if self.norm_type == "ada_norm_zero":
+ ff_output = gate_mlp.unsqueeze(1) * ff_output
+ elif self.norm_type == "ada_norm_single":
+ ff_output = gate_mlp * ff_output
+
+ hidden_states = ff_output + hidden_states
+ if hidden_states.ndim == 4:
+ hidden_states = hidden_states.squeeze(1)
+
+ return hidden_states
+
+
+class LuminaFeedForward(nn.Module):
+ r"""
+ A feed-forward layer.
+
+ Parameters:
+        dim (`int`):
+            The dimensionality of the input hidden states. This parameter determines the width of the model's
+            hidden representations.
+        inner_dim (`int`): The intermediate dimension of the feedforward layer.
+        multiple_of (`int`, *optional*, defaults to 256): Value to ensure the hidden dimension is a multiple
+            of this value.
+        ffn_dim_multiplier (`float`, *optional*): Custom multiplier for the hidden
+            dimension. Defaults to `None`.
+ """
+
+ def __init__(
+ self,
+ dim: int,
+ inner_dim: int,
+ multiple_of: Optional[int] = 256,
+ ffn_dim_multiplier: Optional[float] = None,
+ ):
+ super().__init__()
+ inner_dim = int(2 * inner_dim / 3)
+ # custom hidden_size factor multiplier
+ if ffn_dim_multiplier is not None:
+ inner_dim = int(ffn_dim_multiplier * inner_dim)
+ inner_dim = multiple_of * ((inner_dim + multiple_of - 1) // multiple_of)
+
+ self.linear_1 = nn.Linear(
+ dim,
+ inner_dim,
+ bias=False,
+ )
+ self.linear_2 = nn.Linear(
+ inner_dim,
+ dim,
+ bias=False,
+ )
+ self.linear_3 = nn.Linear(
+ dim,
+ inner_dim,
+ bias=False,
+ )
+ self.silu = FP32SiLU()
+
+ def forward(self, x):
+ return self.linear_2(self.silu(self.linear_1(x)) * self.linear_3(x))
+
+
+@maybe_allow_in_graph
+class TemporalBasicTransformerBlock(nn.Module):
+ r"""
+ A basic Transformer block for video like data.
+
+ Parameters:
+ dim (`int`): The number of channels in the input and output.
+ time_mix_inner_dim (`int`): The number of channels for temporal attention.
+ num_attention_heads (`int`): The number of heads to use for multi-head attention.
+ attention_head_dim (`int`): The number of channels in each head.
+ cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention.
+ """
+
+ def __init__(
+ self,
+ dim: int,
+ time_mix_inner_dim: int,
+ num_attention_heads: int,
+ attention_head_dim: int,
+ cross_attention_dim: Optional[int] = None,
+ ):
+ super().__init__()
+ self.is_res = dim == time_mix_inner_dim
+
+ self.norm_in = nn.LayerNorm(dim)
+
+ # Define 3 blocks. Each block has its own normalization layer.
+ # 1. Self-Attn
+ self.ff_in = FeedForward(
+ dim,
+ dim_out=time_mix_inner_dim,
+ activation_fn="geglu",
+ )
+
+ self.norm1 = nn.LayerNorm(time_mix_inner_dim)
+ self.attn1 = Attention(
+ query_dim=time_mix_inner_dim,
+ heads=num_attention_heads,
+ dim_head=attention_head_dim,
+ cross_attention_dim=None,
+ )
+
+ # 2. Cross-Attn
+ if cross_attention_dim is not None:
+ # We currently only use AdaLayerNormZero for self attention where there will only be one attention block.
+ # I.e. the number of returned modulation chunks from AdaLayerZero would not make sense if returned during
+ # the second cross attention block.
+ self.norm2 = nn.LayerNorm(time_mix_inner_dim)
+ self.attn2 = Attention(
+ query_dim=time_mix_inner_dim,
+ cross_attention_dim=cross_attention_dim,
+ heads=num_attention_heads,
+ dim_head=attention_head_dim,
+ ) # is self-attn if encoder_hidden_states is none
+ else:
+ self.norm2 = None
+ self.attn2 = None
+
+ # 3. Feed-forward
+ self.norm3 = nn.LayerNorm(time_mix_inner_dim)
+ self.ff = FeedForward(time_mix_inner_dim, activation_fn="geglu")
+
+ # let chunk size default to None
+ self._chunk_size = None
+ self._chunk_dim = None
+
+ def set_chunk_feed_forward(self, chunk_size: Optional[int], **kwargs):
+ # Sets chunk feed-forward
+ self._chunk_size = chunk_size
+ # chunk dim should be hardcoded to 1 to have better speed vs. memory trade-off
+ self._chunk_dim = 1
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ num_frames: int,
+ encoder_hidden_states: Optional[torch.Tensor] = None,
+ ) -> torch.Tensor:
+ # Notice that normalization is always applied before the real computation in the following blocks.
+ # 0. Self-Attention
+ batch_size = hidden_states.shape[0]
+
+ batch_frames, seq_length, channels = hidden_states.shape
+ batch_size = batch_frames // num_frames
+
+ hidden_states = hidden_states[None, :].reshape(batch_size, num_frames, seq_length, channels)
+ hidden_states = hidden_states.permute(0, 2, 1, 3)
+ hidden_states = hidden_states.reshape(batch_size * seq_length, num_frames, channels)
+
+ residual = hidden_states
+ hidden_states = self.norm_in(hidden_states)
+
+ if self._chunk_size is not None:
+ hidden_states = _chunked_feed_forward(self.ff_in, hidden_states, self._chunk_dim, self._chunk_size)
+ else:
+ hidden_states = self.ff_in(hidden_states)
+
+ if self.is_res:
+ hidden_states = hidden_states + residual
+
+ norm_hidden_states = self.norm1(hidden_states)
+ attn_output = self.attn1(norm_hidden_states, encoder_hidden_states=None)
+ hidden_states = attn_output + hidden_states
+
+ # 3. Cross-Attention
+ if self.attn2 is not None:
+ norm_hidden_states = self.norm2(hidden_states)
+ attn_output = self.attn2(norm_hidden_states, encoder_hidden_states=encoder_hidden_states)
+ hidden_states = attn_output + hidden_states
+
+ # 4. Feed-forward
+ norm_hidden_states = self.norm3(hidden_states)
+
+ if self._chunk_size is not None:
+ ff_output = _chunked_feed_forward(self.ff, norm_hidden_states, self._chunk_dim, self._chunk_size)
+ else:
+ ff_output = self.ff(norm_hidden_states)
+
+ if self.is_res:
+ hidden_states = ff_output + hidden_states
+ else:
+ hidden_states = ff_output
+
+ hidden_states = hidden_states[None, :].reshape(batch_size, seq_length, num_frames, channels)
+ hidden_states = hidden_states.permute(0, 2, 1, 3)
+ hidden_states = hidden_states.reshape(batch_size * num_frames, seq_length, channels)
+
+ return hidden_states
+
+
+class SkipFFTransformerBlock(nn.Module):
+ def __init__(
+ self,
+ dim: int,
+ num_attention_heads: int,
+ attention_head_dim: int,
+ kv_input_dim: int,
+ kv_input_dim_proj_use_bias: bool,
+ dropout=0.0,
+ cross_attention_dim: Optional[int] = None,
+ attention_bias: bool = False,
+ attention_out_bias: bool = True,
+ ):
+ super().__init__()
+ if kv_input_dim != dim:
+ self.kv_mapper = nn.Linear(kv_input_dim, dim, kv_input_dim_proj_use_bias)
+ else:
+ self.kv_mapper = None
+
+ self.norm1 = RMSNorm(dim, 1e-06)
+
+ self.attn1 = Attention(
+ query_dim=dim,
+ heads=num_attention_heads,
+ dim_head=attention_head_dim,
+ dropout=dropout,
+ bias=attention_bias,
+ cross_attention_dim=cross_attention_dim,
+ out_bias=attention_out_bias,
+ )
+
+ self.norm2 = RMSNorm(dim, 1e-06)
+
+ self.attn2 = Attention(
+ query_dim=dim,
+ cross_attention_dim=cross_attention_dim,
+ heads=num_attention_heads,
+ dim_head=attention_head_dim,
+ dropout=dropout,
+ bias=attention_bias,
+ out_bias=attention_out_bias,
+ )
+
+ def forward(self, hidden_states, encoder_hidden_states, cross_attention_kwargs):
+ cross_attention_kwargs = cross_attention_kwargs.copy() if cross_attention_kwargs is not None else {}
+
+ if self.kv_mapper is not None:
+ encoder_hidden_states = self.kv_mapper(F.silu(encoder_hidden_states))
+
+ norm_hidden_states = self.norm1(hidden_states)
+
+ attn_output = self.attn1(
+ norm_hidden_states,
+ encoder_hidden_states=encoder_hidden_states,
+ **cross_attention_kwargs,
+ )
+
+ hidden_states = attn_output + hidden_states
+
+ norm_hidden_states = self.norm2(hidden_states)
+
+ attn_output = self.attn2(
+ norm_hidden_states,
+ encoder_hidden_states=encoder_hidden_states,
+ **cross_attention_kwargs,
+ )
+
+ hidden_states = attn_output + hidden_states
+
+ return hidden_states
+
+
+@maybe_allow_in_graph
+class FreeNoiseTransformerBlock(nn.Module):
+ r"""
+ A FreeNoise Transformer block.
+
+ Parameters:
+ dim (`int`):
+ The number of channels in the input and output.
+ num_attention_heads (`int`):
+ The number of heads to use for multi-head attention.
+ attention_head_dim (`int`):
+ The number of channels in each head.
+ dropout (`float`, *optional*, defaults to 0.0):
+ The dropout probability to use.
+ cross_attention_dim (`int`, *optional*):
+ The size of the encoder_hidden_states vector for cross attention.
+ activation_fn (`str`, *optional*, defaults to `"geglu"`):
+ Activation function to be used in feed-forward.
+ num_embeds_ada_norm (`int`, *optional*):
+ The number of diffusion steps used during training. See `Transformer2DModel`.
+ attention_bias (`bool`, defaults to `False`):
+ Configure if the attentions should contain a bias parameter.
+ only_cross_attention (`bool`, defaults to `False`):
+ Whether to use only cross-attention layers. In this case two cross attention layers are used.
+ double_self_attention (`bool`, defaults to `False`):
+ Whether to use two self-attention layers. In this case no cross attention layers are used.
+ upcast_attention (`bool`, defaults to `False`):
+ Whether to upcast the attention computation to float32. This is useful for mixed precision training.
+ norm_elementwise_affine (`bool`, defaults to `True`):
+ Whether to use learnable elementwise affine parameters for normalization.
+ norm_type (`str`, defaults to `"layer_norm"`):
+ The normalization layer to use. Can be `"layer_norm"`, `"ada_norm"` or `"ada_norm_zero"`.
+ final_dropout (`bool` defaults to `False`):
+ Whether to apply a final dropout after the last feed-forward layer.
+ attention_type (`str`, defaults to `"default"`):
+ The type of attention to use. Can be `"default"` or `"gated"` or `"gated-text-image"`.
+ positional_embeddings (`str`, *optional*):
+ The type of positional embeddings to apply to.
+ num_positional_embeddings (`int`, *optional*, defaults to `None`):
+ The maximum number of positional embeddings to apply.
+ ff_inner_dim (`int`, *optional*):
+ Hidden dimension of feed-forward MLP.
+ ff_bias (`bool`, defaults to `True`):
+ Whether or not to use bias in feed-forward MLP.
+ attention_out_bias (`bool`, defaults to `True`):
+ Whether or not to use bias in attention output project layer.
+ context_length (`int`, defaults to `16`):
+ The maximum number of frames that the FreeNoise block processes at once.
+ context_stride (`int`, defaults to `4`):
+ The number of frames to be skipped before starting to process a new batch of `context_length` frames.
+ weighting_scheme (`str`, defaults to `"pyramid"`):
+ The weighting scheme to use for weighting averaging of processed latent frames. As described in the
+ Equation 9. of the [FreeNoise](https://arxiv.org/abs/2310.15169) paper, "pyramid" is the default setting
+ used.
+ """
+
+ def __init__(
+ self,
+ dim: int,
+ num_attention_heads: int,
+ attention_head_dim: int,
+ dropout: float = 0.0,
+ cross_attention_dim: Optional[int] = None,
+ activation_fn: str = "geglu",
+ num_embeds_ada_norm: Optional[int] = None,
+ attention_bias: bool = False,
+ only_cross_attention: bool = False,
+ double_self_attention: bool = False,
+ upcast_attention: bool = False,
+ norm_elementwise_affine: bool = True,
+ norm_type: str = "layer_norm",
+ norm_eps: float = 1e-5,
+ final_dropout: bool = False,
+ positional_embeddings: Optional[str] = None,
+ num_positional_embeddings: Optional[int] = None,
+ ff_inner_dim: Optional[int] = None,
+ ff_bias: bool = True,
+ attention_out_bias: bool = True,
+ context_length: int = 16,
+ context_stride: int = 4,
+ weighting_scheme: str = "pyramid",
+ ):
+ super().__init__()
+ self.dim = dim
+ self.num_attention_heads = num_attention_heads
+ self.attention_head_dim = attention_head_dim
+ self.dropout = dropout
+ self.cross_attention_dim = cross_attention_dim
+ self.activation_fn = activation_fn
+ self.attention_bias = attention_bias
+ self.double_self_attention = double_self_attention
+ self.norm_elementwise_affine = norm_elementwise_affine
+ self.positional_embeddings = positional_embeddings
+ self.num_positional_embeddings = num_positional_embeddings
+ self.only_cross_attention = only_cross_attention
+
+ self.set_free_noise_properties(context_length, context_stride, weighting_scheme)
+
+ # We keep these boolean flags for backward-compatibility.
+ self.use_ada_layer_norm_zero = (num_embeds_ada_norm is not None) and norm_type == "ada_norm_zero"
+ self.use_ada_layer_norm = (num_embeds_ada_norm is not None) and norm_type == "ada_norm"
+ self.use_ada_layer_norm_single = norm_type == "ada_norm_single"
+ self.use_layer_norm = norm_type == "layer_norm"
+ self.use_ada_layer_norm_continuous = norm_type == "ada_norm_continuous"
+
+ if norm_type in ("ada_norm", "ada_norm_zero") and num_embeds_ada_norm is None:
+ raise ValueError(
+ f"`norm_type` is set to {norm_type}, but `num_embeds_ada_norm` is not defined. Please make sure to"
+ f" define `num_embeds_ada_norm` if setting `norm_type` to {norm_type}."
+ )
+
+ self.norm_type = norm_type
+ self.num_embeds_ada_norm = num_embeds_ada_norm
+
+ if positional_embeddings and (num_positional_embeddings is None):
+ raise ValueError(
+ "If `positional_embedding` type is defined, `num_positition_embeddings` must also be defined."
+ )
+
+ if positional_embeddings == "sinusoidal":
+ self.pos_embed = SinusoidalPositionalEmbedding(dim, max_seq_length=num_positional_embeddings)
+ else:
+ self.pos_embed = None
+
+ # Define 3 blocks. Each block has its own normalization layer.
+ # 1. Self-Attn
+ self.norm1 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)
+
+ self.attn1 = Attention(
+ query_dim=dim,
+ heads=num_attention_heads,
+ dim_head=attention_head_dim,
+ dropout=dropout,
+ bias=attention_bias,
+ cross_attention_dim=cross_attention_dim if only_cross_attention else None,
+ upcast_attention=upcast_attention,
+ out_bias=attention_out_bias,
+ )
+
+ # 2. Cross-Attn
+ if cross_attention_dim is not None or double_self_attention:
+ self.norm2 = nn.LayerNorm(dim, norm_eps, norm_elementwise_affine)
+
+ self.attn2 = Attention(
+ query_dim=dim,
+ cross_attention_dim=cross_attention_dim if not double_self_attention else None,
+ heads=num_attention_heads,
+ dim_head=attention_head_dim,
+ dropout=dropout,
+ bias=attention_bias,
+ upcast_attention=upcast_attention,
+ out_bias=attention_out_bias,
+ ) # is self-attn if encoder_hidden_states is none
+
+ # 3. Feed-forward
+ self.ff = FeedForward(
+ dim,
+ dropout=dropout,
+ activation_fn=activation_fn,
+ final_dropout=final_dropout,
+ inner_dim=ff_inner_dim,
+ bias=ff_bias,
+ )
+
+ self.norm3 = nn.LayerNorm(dim, norm_eps, norm_elementwise_affine)
+
+ # let chunk size default to None
+ self._chunk_size = None
+ self._chunk_dim = 0
+
+ def _get_frame_indices(self, num_frames: int) -> List[Tuple[int, int]]:
+ frame_indices = []
+ for i in range(0, num_frames - self.context_length + 1, self.context_stride):
+ window_start = i
+ window_end = min(num_frames, i + self.context_length)
+ frame_indices.append((window_start, window_end))
+ return frame_indices
+
+ def _get_frame_weights(self, num_frames: int, weighting_scheme: str = "pyramid") -> List[float]:
+ if weighting_scheme == "flat":
+ weights = [1.0] * num_frames
+
+ elif weighting_scheme == "pyramid":
+ if num_frames % 2 == 0:
+ # num_frames = 4 => [1, 2, 2, 1]
+ mid = num_frames // 2
+ weights = list(range(1, mid + 1))
+ weights = weights + weights[::-1]
+ else:
+ # num_frames = 5 => [1, 2, 3, 2, 1]
+ mid = (num_frames + 1) // 2
+ weights = list(range(1, mid))
+ weights = weights + [mid] + weights[::-1]
+
+ elif weighting_scheme == "delayed_reverse_sawtooth":
+ if num_frames % 2 == 0:
+ # num_frames = 4 => [0.01, 2, 2, 1]
+ mid = num_frames // 2
+ weights = [0.01] * (mid - 1) + [mid]
+ weights = weights + list(range(mid, 0, -1))
+ else:
+ # num_frames = 5 => [0.01, 0.01, 3, 2, 1]
+ mid = (num_frames + 1) // 2
+ weights = [0.01] * mid
+ weights = weights + list(range(mid, 0, -1))
+ else:
+ raise ValueError(f"Unsupported value for weighting_scheme={weighting_scheme}")
+
+ return weights
+
+ def set_free_noise_properties(
+ self, context_length: int, context_stride: int, weighting_scheme: str = "pyramid"
+ ) -> None:
+ self.context_length = context_length
+ self.context_stride = context_stride
+ self.weighting_scheme = weighting_scheme
+
+ def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int = 0) -> None:
+ # Sets chunk feed-forward
+ self._chunk_size = chunk_size
+ self._chunk_dim = dim
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ encoder_hidden_states: Optional[torch.Tensor] = None,
+ encoder_attention_mask: Optional[torch.Tensor] = None,
+ cross_attention_kwargs: Dict[str, Any] = None,
+ *args,
+ **kwargs,
+ ) -> torch.Tensor:
+ if cross_attention_kwargs is not None:
+ if cross_attention_kwargs.get("scale", None) is not None:
+ logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.")
+
+ cross_attention_kwargs = cross_attention_kwargs.copy() if cross_attention_kwargs is not None else {}
+
+ # hidden_states: [B x H x W, F, C]
+ device = hidden_states.device
+ dtype = hidden_states.dtype
+
+ num_frames = hidden_states.size(1)
+ frame_indices = self._get_frame_indices(num_frames)
+ frame_weights = self._get_frame_weights(self.context_length, self.weighting_scheme)
+ frame_weights = torch.tensor(frame_weights, device=device, dtype=dtype).unsqueeze(0).unsqueeze(-1)
+ is_last_frame_batch_complete = frame_indices[-1][1] == num_frames
+
+ # Handle out-of-bounds case if num_frames isn't perfectly divisible by context_length
+ # For example, num_frames=25, context_length=16, context_stride=4, then we expect the ranges:
+ # [(0, 16), (4, 20), (8, 24), (10, 26)]
+ if not is_last_frame_batch_complete:
+ if num_frames < self.context_length:
+ raise ValueError(f"Expected {num_frames=} to be greater or equal than {self.context_length=}")
+ last_frame_batch_length = num_frames - frame_indices[-1][1]
+ frame_indices.append((num_frames - self.context_length, num_frames))
+
+ num_times_accumulated = torch.zeros((1, num_frames, 1), device=device)
+ accumulated_values = torch.zeros_like(hidden_states)
+
+ for i, (frame_start, frame_end) in enumerate(frame_indices):
+            # The slicing below handles windows shorter than `context_length`, e.g.
+            # frame_indices=[(0, 16), (16, 20)] when the user provides a video whose frame count
+            # is not a multiple of `context_length`.
+ weights = torch.ones_like(num_times_accumulated[:, frame_start:frame_end])
+ weights *= frame_weights
+
+ hidden_states_chunk = hidden_states[:, frame_start:frame_end]
+
+ # Notice that normalization is always applied before the real computation in the following blocks.
+ # 1. Self-Attention
+ norm_hidden_states = self.norm1(hidden_states_chunk)
+
+ if self.pos_embed is not None:
+ norm_hidden_states = self.pos_embed(norm_hidden_states)
+
+ attn_output = self.attn1(
+ norm_hidden_states,
+ encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None,
+ attention_mask=attention_mask,
+ **cross_attention_kwargs,
+ )
+
+ hidden_states_chunk = attn_output + hidden_states_chunk
+ if hidden_states_chunk.ndim == 4:
+ hidden_states_chunk = hidden_states_chunk.squeeze(1)
+
+ # 2. Cross-Attention
+ if self.attn2 is not None:
+ norm_hidden_states = self.norm2(hidden_states_chunk)
+
+ if self.pos_embed is not None and self.norm_type != "ada_norm_single":
+ norm_hidden_states = self.pos_embed(norm_hidden_states)
+
+ attn_output = self.attn2(
+ norm_hidden_states,
+ encoder_hidden_states=encoder_hidden_states,
+ attention_mask=encoder_attention_mask,
+ **cross_attention_kwargs,
+ )
+ hidden_states_chunk = attn_output + hidden_states_chunk
+
+ if i == len(frame_indices) - 1 and not is_last_frame_batch_complete:
+ accumulated_values[:, -last_frame_batch_length:] += (
+ hidden_states_chunk[:, -last_frame_batch_length:] * weights[:, -last_frame_batch_length:]
+ )
+                num_times_accumulated[:, -last_frame_batch_length:] += weights[:, -last_frame_batch_length:]
+ else:
+ accumulated_values[:, frame_start:frame_end] += hidden_states_chunk * weights
+ num_times_accumulated[:, frame_start:frame_end] += weights
+
+ hidden_states = torch.cat(
+ [
+ torch.where(num_times_split > 0, accumulated_split / num_times_split, accumulated_split)
+ for accumulated_split, num_times_split in zip(
+ accumulated_values.split(self.context_length, dim=1),
+ num_times_accumulated.split(self.context_length, dim=1),
+ )
+ ],
+ dim=1,
+ ).to(dtype)
+
+ # 3. Feed-forward
+ norm_hidden_states = self.norm3(hidden_states)
+
+ if self._chunk_size is not None:
+ ff_output = _chunked_feed_forward(self.ff, norm_hidden_states, self._chunk_dim, self._chunk_size)
+ else:
+ ff_output = self.ff(norm_hidden_states)
+
+ hidden_states = ff_output + hidden_states
+ if hidden_states.ndim == 4:
+ hidden_states = hidden_states.squeeze(1)
+
+ return hidden_states
+
+
+class FeedForward(nn.Module):
+ r"""
+ A feed-forward layer.
+
+ Parameters:
+ dim (`int`): The number of channels in the input.
+ dim_out (`int`, *optional*): The number of channels in the output. If not given, defaults to `dim`.
+ mult (`int`, *optional*, defaults to 4): The multiplier to use for the hidden dimension.
+ dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
+ activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
+ final_dropout (`bool` *optional*, defaults to False): Apply a final dropout.
+ bias (`bool`, defaults to True): Whether to use a bias in the linear layer.
+ """
+
+ def __init__(
+ self,
+ dim: int,
+ dim_out: Optional[int] = None,
+ mult: int = 4,
+ dropout: float = 0.0,
+ activation_fn: str = "geglu",
+ final_dropout: bool = False,
+ inner_dim=None,
+ bias: bool = True,
+ ):
+ super().__init__()
+ if inner_dim is None:
+ inner_dim = int(dim * mult)
+ dim_out = dim_out if dim_out is not None else dim
+
+ if activation_fn == "gelu":
+ act_fn = GELU(dim, inner_dim, bias=bias)
+ if activation_fn == "gelu-approximate":
+ act_fn = GELU(dim, inner_dim, approximate="tanh", bias=bias)
+ elif activation_fn == "geglu":
+ act_fn = GEGLU(dim, inner_dim, bias=bias)
+ elif activation_fn == "geglu-approximate":
+ act_fn = ApproximateGELU(dim, inner_dim, bias=bias)
+ elif activation_fn == "swiglu":
+ act_fn = SwiGLU(dim, inner_dim, bias=bias)
+
+ self.net = nn.ModuleList([])
+ # project in
+ self.net.append(act_fn)
+ # project dropout
+ self.net.append(nn.Dropout(dropout))
+ # project out
+ self.net.append(nn.Linear(inner_dim, dim_out, bias=bias))
+ # FF as used in Vision Transformer, MLP-Mixer, etc. have a final dropout
+ if final_dropout:
+ self.net.append(nn.Dropout(dropout))
+
+ def forward(self, hidden_states: torch.Tensor, *args, **kwargs) -> torch.Tensor:
+ if len(args) > 0 or kwargs.get("scale", None) is not None:
+ deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
+ deprecate("scale", "1.0.0", deprecation_message)
+ for module in self.net:
+ hidden_states = module(hidden_states)
+ return hidden_states
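+
+
+# Usage sketch (illustrative comment only, not executed on import; the shapes below are assumptions):
+#   ff = FeedForward(dim=320, mult=4, activation_fn="geglu")  # inner_dim = 320 * 4 = 1280
+#   out = ff(torch.randn(2, 77, 320))                         # GEGLU gates 2*1280 -> 1280, final Linear maps back to 320
+#   assert out.shape == (2, 77, 320)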
diff --git a/MindIE/MindIE-Torch/built-in/foundation/CogVideoX-5b/cogvideox_5b/models/attention_processor.py b/MindIE/MindIE-Torch/built-in/foundation/CogVideoX-5b/cogvideox_5b/models/attention_processor.py
new file mode 100644
index 0000000000000000000000000000000000000000..e1c3c42460b9d1b4ba8136ac4322c0443a1d03b7
--- /dev/null
+++ b/MindIE/MindIE-Torch/built-in/foundation/CogVideoX-5b/cogvideox_5b/models/attention_processor.py
@@ -0,0 +1,4301 @@
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import inspect
+import math
+from typing import Callable, List, Optional, Tuple, Union
+
+import torch
+import torch_npu
+import torch.nn.functional as F
+from torch import nn
+
+from diffusers.image_processor import IPAdapterMaskProcessor
+from diffusers.utils import deprecate, logging
+from diffusers.utils.import_utils import is_torch_npu_available, is_xformers_available
+from diffusers.utils.torch_utils import is_torch_version, maybe_allow_in_graph
+from ..utils.parallel_state import get_world_size, get_rank
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+MAX_TOKENS = 2147483647
+
+if is_xformers_available():
+ import xformers
+ import xformers.ops
+else:
+ xformers = None
+
+
+@maybe_allow_in_graph
+class Attention(nn.Module):
+ r"""
+ A cross attention layer.
+
+ Parameters:
+ query_dim (`int`):
+ The number of channels in the query.
+ cross_attention_dim (`int`, *optional*):
+ The number of channels in the encoder_hidden_states. If not given, defaults to `query_dim`.
+ heads (`int`, *optional*, defaults to 8):
+ The number of heads to use for multi-head attention.
+ kv_heads (`int`, *optional*, defaults to `None`):
+ The number of key and value heads to use for multi-head attention. Defaults to `heads`. If
+ `kv_heads=heads`, the model will use Multi Head Attention (MHA), if `kv_heads=1` the model will use Multi
+ Query Attention (MQA) otherwise GQA is used.
+ dim_head (`int`, *optional*, defaults to 64):
+ The number of channels in each head.
+ dropout (`float`, *optional*, defaults to 0.0):
+ The dropout probability to use.
+ bias (`bool`, *optional*, defaults to False):
+ Set to `True` for the query, key, and value linear layers to contain a bias parameter.
+ upcast_attention (`bool`, *optional*, defaults to False):
+ Set to `True` to upcast the attention computation to `float32`.
+ upcast_softmax (`bool`, *optional*, defaults to False):
+ Set to `True` to upcast the softmax computation to `float32`.
+ cross_attention_norm (`str`, *optional*, defaults to `None`):
+ The type of normalization to use for the cross attention. Can be `None`, `layer_norm`, or `group_norm`.
+ cross_attention_norm_num_groups (`int`, *optional*, defaults to 32):
+ The number of groups to use for the group norm in the cross attention.
+ added_kv_proj_dim (`int`, *optional*, defaults to `None`):
+ The number of channels to use for the added key and value projections. If `None`, no projection is used.
+ norm_num_groups (`int`, *optional*, defaults to `None`):
+ The number of groups to use for the group norm in the attention.
+ spatial_norm_dim (`int`, *optional*, defaults to `None`):
+ The number of channels to use for the spatial normalization.
+ out_bias (`bool`, *optional*, defaults to `True`):
+ Set to `True` to use a bias in the output linear layer.
+ scale_qk (`bool`, *optional*, defaults to `True`):
+ Set to `True` to scale the query and key by `1 / sqrt(dim_head)`.
+ only_cross_attention (`bool`, *optional*, defaults to `False`):
+ Set to `True` to only use cross attention and not added_kv_proj_dim. Can only be set to `True` if
+ `added_kv_proj_dim` is not `None`.
+ eps (`float`, *optional*, defaults to 1e-5):
+ An additional value added to the denominator in group normalization that is used for numerical stability.
+ rescale_output_factor (`float`, *optional*, defaults to 1.0):
+ A factor to rescale the output by dividing it with this value.
+ residual_connection (`bool`, *optional*, defaults to `False`):
+ Set to `True` to add the residual connection to the output.
+ _from_deprecated_attn_block (`bool`, *optional*, defaults to `False`):
+ Set to `True` if the attention block is loaded from a deprecated state dict.
+ processor (`AttnProcessor`, *optional*, defaults to `None`):
+ The attention processor to use. If `None`, defaults to `AttnProcessor2_0` if `torch 2.x` is used and
+ `AttnProcessor` otherwise.
+ """
+
+ def __init__(
+ self,
+ query_dim: int,
+ cross_attention_dim: Optional[int] = None,
+ heads: int = 8,
+ kv_heads: Optional[int] = None,
+ dim_head: int = 64,
+ dropout: float = 0.0,
+ bias: bool = False,
+ upcast_attention: bool = False,
+ upcast_softmax: bool = False,
+ cross_attention_norm: Optional[str] = None,
+ cross_attention_norm_num_groups: int = 32,
+ qk_norm: Optional[str] = None,
+ added_kv_proj_dim: Optional[int] = None,
+ added_proj_bias: Optional[bool] = True,
+ norm_num_groups: Optional[int] = None,
+ spatial_norm_dim: Optional[int] = None,
+ out_bias: bool = True,
+ scale_qk: bool = True,
+ only_cross_attention: bool = False,
+ eps: float = 1e-5,
+ rescale_output_factor: float = 1.0,
+ residual_connection: bool = False,
+ _from_deprecated_attn_block: bool = False,
+ processor: Optional["AttnProcessor"] = None,
+        out_dim: Optional[int] = None,
+        context_pre_only: Optional[bool] = None,
+        pre_only: bool = False,
+ elementwise_affine: bool = True,
+ ):
+ super().__init__()
+
+ # To prevent circular import.
+ from .normalization import FP32LayerNorm, RMSNorm
+
+ self.inner_dim = out_dim if out_dim is not None else dim_head * heads
+ self.inner_kv_dim = self.inner_dim if kv_heads is None else dim_head * kv_heads
+ self.query_dim = query_dim
+ self.use_bias = bias
+ self.is_cross_attention = cross_attention_dim is not None
+ self.cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim
+ self.upcast_attention = upcast_attention
+ self.upcast_softmax = upcast_softmax
+ self.rescale_output_factor = rescale_output_factor
+ self.residual_connection = residual_connection
+ self.dropout = dropout
+ self.fused_projections = False
+ self.out_dim = out_dim if out_dim is not None else query_dim
+ self.context_pre_only = context_pre_only
+ self.pre_only = pre_only
+
+ # we make use of this private variable to know whether this class is loaded
+ # with an deprecated state dict so that we can convert it on the fly
+ self._from_deprecated_attn_block = _from_deprecated_attn_block
+
+ self.scale_qk = scale_qk
+ self.scale = dim_head**-0.5 if self.scale_qk else 1.0
+
+ self.heads = out_dim // dim_head if out_dim is not None else heads
+ # for slice_size > 0 the attention score computation
+ # is split across the batch axis to save memory
+ # You can set slice_size with `set_attention_slice`
+ self.sliceable_head_dim = heads
+
+ self.added_kv_proj_dim = added_kv_proj_dim
+ self.only_cross_attention = only_cross_attention
+
+ if self.added_kv_proj_dim is None and self.only_cross_attention:
+ raise ValueError(
+ "`only_cross_attention` can only be set to True if `added_kv_proj_dim` is not None. Make sure to set either `only_cross_attention=False` or define `added_kv_proj_dim`."
+ )
+
+ if norm_num_groups is not None:
+ self.group_norm = nn.GroupNorm(num_channels=query_dim, num_groups=norm_num_groups, eps=eps, affine=True)
+ else:
+ self.group_norm = None
+
+ if spatial_norm_dim is not None:
+ self.spatial_norm = SpatialNorm(f_channels=query_dim, zq_channels=spatial_norm_dim)
+ else:
+ self.spatial_norm = None
+
+ if qk_norm is None:
+ self.norm_q = None
+ self.norm_k = None
+ elif qk_norm == "layer_norm":
+ self.norm_q = nn.LayerNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine)
+ self.norm_k = nn.LayerNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine)
+ elif qk_norm == "fp32_layer_norm":
+ self.norm_q = FP32LayerNorm(dim_head, elementwise_affine=False, bias=False, eps=eps)
+ self.norm_k = FP32LayerNorm(dim_head, elementwise_affine=False, bias=False, eps=eps)
+ elif qk_norm == "layer_norm_across_heads":
+            # Lumina applies qk norm across all heads
+ self.norm_q = nn.LayerNorm(dim_head * heads, eps=eps)
+ self.norm_k = nn.LayerNorm(dim_head * kv_heads, eps=eps)
+ elif qk_norm == "rms_norm":
+ self.norm_q = RMSNorm(dim_head, eps=eps)
+ self.norm_k = RMSNorm(dim_head, eps=eps)
+ else:
+ raise ValueError(f"unknown qk_norm: {qk_norm}. Should be None,'layer_norm','fp32_layer_norm','rms_norm'")
+
+ if cross_attention_norm is None:
+ self.norm_cross = None
+ elif cross_attention_norm == "layer_norm":
+ self.norm_cross = nn.LayerNorm(self.cross_attention_dim)
+ elif cross_attention_norm == "group_norm":
+ if self.added_kv_proj_dim is not None:
+ # The given `encoder_hidden_states` are initially of shape
+ # (batch_size, seq_len, added_kv_proj_dim) before being projected
+ # to (batch_size, seq_len, cross_attention_dim). The norm is applied
+ # before the projection, so we need to use `added_kv_proj_dim` as
+ # the number of channels for the group norm.
+ norm_cross_num_channels = added_kv_proj_dim
+ else:
+ norm_cross_num_channels = self.cross_attention_dim
+
+ self.norm_cross = nn.GroupNorm(
+ num_channels=norm_cross_num_channels, num_groups=cross_attention_norm_num_groups, eps=1e-5, affine=True
+ )
+ else:
+ raise ValueError(
+ f"unknown cross_attention_norm: {cross_attention_norm}. Should be None, 'layer_norm' or 'group_norm'"
+ )
+
+ self.to_q = nn.Linear(query_dim, self.inner_dim, bias=bias)
+
+ if not self.only_cross_attention:
+ # only relevant for the `AddedKVProcessor` classes
+ self.to_k = nn.Linear(self.cross_attention_dim, self.inner_kv_dim, bias=bias)
+ self.to_v = nn.Linear(self.cross_attention_dim, self.inner_kv_dim, bias=bias)
+ else:
+ self.to_k = None
+ self.to_v = None
+
+ self.added_proj_bias = added_proj_bias
+ if self.added_kv_proj_dim is not None:
+ self.add_k_proj = nn.Linear(added_kv_proj_dim, self.inner_kv_dim, bias=added_proj_bias)
+ self.add_v_proj = nn.Linear(added_kv_proj_dim, self.inner_kv_dim, bias=added_proj_bias)
+ if self.context_pre_only is not None:
+ self.add_q_proj = nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias)
+
+ if not self.pre_only:
+ self.to_out = nn.ModuleList([])
+ self.to_out.append(nn.Linear(self.inner_dim, self.out_dim, bias=out_bias))
+ self.to_out.append(nn.Dropout(dropout))
+
+ if self.context_pre_only is not None and not self.context_pre_only:
+ self.to_add_out = nn.Linear(self.inner_dim, self.out_dim, bias=out_bias)
+
+ if qk_norm is not None and added_kv_proj_dim is not None:
+ if qk_norm == "fp32_layer_norm":
+ self.norm_added_q = FP32LayerNorm(dim_head, elementwise_affine=False, bias=False, eps=eps)
+ self.norm_added_k = FP32LayerNorm(dim_head, elementwise_affine=False, bias=False, eps=eps)
+ elif qk_norm == "rms_norm":
+ self.norm_added_q = RMSNorm(dim_head, eps=eps)
+ self.norm_added_k = RMSNorm(dim_head, eps=eps)
+ else:
+ raise ValueError(
+ f"unknown qk_norm: {qk_norm}. Should be one of `None,'layer_norm','fp32_layer_norm','rms_norm'`"
+ )
+ else:
+ self.norm_added_q = None
+ self.norm_added_k = None
+
+ if processor is None:
+ processor = (
+ AttnProcessor2_0() if hasattr(F, "scaled_dot_product_attention") and self.scale_qk else AttnProcessor()
+ )
+ self.set_processor(processor)
+
+ def set_use_npu_flash_attention(self, use_npu_flash_attention: bool) -> None:
+ r"""
+ Set whether to use npu flash attention from `torch_npu` or not.
+
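+        Passing `True` switches the module to `AttnProcessorNPU`, which is intended to route the
+        attention computation through torch_npu's fused attention kernels; passing `False` restores
+        the default processor (`AttnProcessor2_0` when `F.scaled_dot_product_attention` is
+        available, `AttnProcessor` otherwise).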
+ """
+ if use_npu_flash_attention:
+ processor = AttnProcessorNPU()
+ else:
+ processor = (
+ AttnProcessor2_0() if hasattr(F, "scaled_dot_product_attention") and self.scale_qk else AttnProcessor()
+ )
+ self.set_processor(processor)
+
+ def set_use_memory_efficient_attention_xformers(
+ self, use_memory_efficient_attention_xformers: bool, attention_op: Optional[Callable] = None
+ ) -> None:
+ r"""
+ Set whether to use memory efficient attention from `xformers` or not.
+
+ Args:
+ use_memory_efficient_attention_xformers (`bool`):
+ Whether to use memory efficient attention from `xformers` or not.
+ attention_op (`Callable`, *optional*):
+ The attention operation to use. Defaults to `None` which uses the default attention operation from
+ `xformers`.
+ """
+ is_custom_diffusion = hasattr(self, "processor") and isinstance(
+ self.processor,
+ (CustomDiffusionAttnProcessor, CustomDiffusionXFormersAttnProcessor, CustomDiffusionAttnProcessor2_0),
+ )
+ is_added_kv_processor = hasattr(self, "processor") and isinstance(
+ self.processor,
+ (
+ AttnAddedKVProcessor,
+ AttnAddedKVProcessor2_0,
+ SlicedAttnAddedKVProcessor,
+ XFormersAttnAddedKVProcessor,
+ ),
+ )
+
+ if use_memory_efficient_attention_xformers:
+ if is_added_kv_processor and is_custom_diffusion:
+ raise NotImplementedError(
+ f"Memory efficient attention is currently not supported for custom diffusion for attention processor type {self.processor}"
+ )
+ if not is_xformers_available():
+ raise ModuleNotFoundError(
+ (
+ "Refer to https://github.com/facebookresearch/xformers for more information on how to install"
+ " xformers"
+ ),
+ name="xformers",
+ )
+ elif not torch.cuda.is_available():
+ raise ValueError(
+ "torch.cuda.is_available() should be True but is False. xformers' memory efficient attention is"
+ " only available for GPU "
+ )
+ else:
+ try:
+ # Make sure we can run the memory efficient attention
+ _ = xformers.ops.memory_efficient_attention(
+ torch.randn((1, 2, 40), device="cuda"),
+ torch.randn((1, 2, 40), device="cuda"),
+ torch.randn((1, 2, 40), device="cuda"),
+ )
+ except Exception as e:
+ raise e
+
+ if is_custom_diffusion:
+ processor = CustomDiffusionXFormersAttnProcessor(
+ train_kv=self.processor.train_kv,
+ train_q_out=self.processor.train_q_out,
+ hidden_size=self.processor.hidden_size,
+ cross_attention_dim=self.processor.cross_attention_dim,
+ attention_op=attention_op,
+ )
+ processor.load_state_dict(self.processor.state_dict())
+ if hasattr(self.processor, "to_k_custom_diffusion"):
+ processor.to(self.processor.to_k_custom_diffusion.weight.device)
+ elif is_added_kv_processor:
+ logger.info(
+ "Memory efficient attention with `xformers` might currently not work correctly if an attention mask is required for the attention operation."
+ )
+ processor = XFormersAttnAddedKVProcessor(attention_op=attention_op)
+ else:
+ processor = XFormersAttnProcessor(attention_op=attention_op)
+ else:
+ if is_custom_diffusion:
+ attn_processor_class = (
+ CustomDiffusionAttnProcessor2_0
+ if hasattr(F, "scaled_dot_product_attention")
+ else CustomDiffusionAttnProcessor
+ )
+ processor = attn_processor_class(
+ train_kv=self.processor.train_kv,
+ train_q_out=self.processor.train_q_out,
+ hidden_size=self.processor.hidden_size,
+ cross_attention_dim=self.processor.cross_attention_dim,
+ )
+ processor.load_state_dict(self.processor.state_dict())
+ if hasattr(self.processor, "to_k_custom_diffusion"):
+ processor.to(self.processor.to_k_custom_diffusion.weight.device)
+ else:
+ processor = (
+ AttnProcessor2_0()
+ if hasattr(F, "scaled_dot_product_attention") and self.scale_qk
+ else AttnProcessor()
+ )
+
+ self.set_processor(processor)
+
+ def set_attention_slice(self, slice_size: int) -> None:
+ r"""
+ Set the slice size for attention computation.
+
+ Args:
+ slice_size (`int`):
+ The slice size for attention computation.
+ """
+ if slice_size is not None and slice_size > self.sliceable_head_dim:
+ raise ValueError(f"slice_size {slice_size} has to be smaller or equal to {self.sliceable_head_dim}.")
+
+ if slice_size is not None and self.added_kv_proj_dim is not None:
+ processor = SlicedAttnAddedKVProcessor(slice_size)
+ elif slice_size is not None:
+ processor = SlicedAttnProcessor(slice_size)
+ elif self.added_kv_proj_dim is not None:
+ processor = AttnAddedKVProcessor()
+ else:
+ processor = (
+ AttnProcessor2_0() if hasattr(F, "scaled_dot_product_attention") and self.scale_qk else AttnProcessor()
+ )
+
+ self.set_processor(processor)
+
+ def set_processor(self, processor: "AttnProcessor") -> None:
+ r"""
+ Set the attention processor to use.
+
+ Args:
+ processor (`AttnProcessor`):
+ The attention processor to use.
+ """
+ # if current processor is in `self._modules` and if passed `processor` is not, we need to
+ # pop `processor` from `self._modules`
+ if (
+ hasattr(self, "processor")
+ and isinstance(self.processor, torch.nn.Module)
+ and not isinstance(processor, torch.nn.Module)
+ ):
+ logger.info(f"You are removing possibly trained weights of {self.processor} with {processor}")
+ self._modules.pop("processor")
+
+ self.processor = processor
+
+ def get_processor(self, return_deprecated_lora: bool = False) -> "AttentionProcessor":
+ r"""
+ Get the attention processor in use.
+
+ Args:
+ return_deprecated_lora (`bool`, *optional*, defaults to `False`):
+ Set to `True` to return the deprecated LoRA attention processor.
+
+ Returns:
+ "AttentionProcessor": The attention processor in use.
+ """
+ if not return_deprecated_lora:
+ return self.processor
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ encoder_hidden_states: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ **cross_attention_kwargs,
+ ) -> torch.Tensor:
+ r"""
+ The forward method of the `Attention` class.
+
+ Args:
+ hidden_states (`torch.Tensor`):
+ The hidden states of the query.
+ encoder_hidden_states (`torch.Tensor`, *optional*):
+ The hidden states of the encoder.
+ attention_mask (`torch.Tensor`, *optional*):
+ The attention mask to use. If `None`, no mask is applied.
+ **cross_attention_kwargs:
+ Additional keyword arguments to pass along to the cross attention.
+
+ Returns:
+ `torch.Tensor`: The output of the attention layer.
+ """
+ # The `Attention` class can call different attention processors / attention functions
+ # here we simply pass along all tensors to the selected processor class
+ # For standard processors that are defined here, `**cross_attention_kwargs` is empty
+
+ attn_parameters = set(inspect.signature(self.processor.__call__).parameters.keys())
+ quiet_attn_parameters = {"ip_adapter_masks"}
+ unused_kwargs = [
+ k for k, _ in cross_attention_kwargs.items() if k not in attn_parameters and k not in quiet_attn_parameters
+ ]
+ if len(unused_kwargs) > 0:
+ logger.warning(
+ f"cross_attention_kwargs {unused_kwargs} are not expected by {self.processor.__class__.__name__} and will be ignored."
+ )
+ cross_attention_kwargs = {k: w for k, w in cross_attention_kwargs.items() if k in attn_parameters}
+
+ return self.processor(
+ self,
+ hidden_states,
+ encoder_hidden_states=encoder_hidden_states,
+ attention_mask=attention_mask,
+ **cross_attention_kwargs,
+ )
+
+ def batch_to_head_dim(self, tensor: torch.Tensor) -> torch.Tensor:
+ r"""
+ Reshape the tensor from `[batch_size, seq_len, dim]` to `[batch_size // heads, seq_len, dim * heads]`. `heads`
+ is the number of heads initialized while constructing the `Attention` class.
+
+ Args:
+ tensor (`torch.Tensor`): The tensor to reshape.
+
+ Returns:
+ `torch.Tensor`: The reshaped tensor.
+ """
+ head_size = self.heads
+ batch_size, seq_len, dim = tensor.shape
+ tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim)
+ tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size)
+ return tensor
+
+ def head_to_batch_dim(self, tensor: torch.Tensor, out_dim: int = 3) -> torch.Tensor:
+ r"""
+ Reshape the tensor from `[batch_size, seq_len, dim]` to `[batch_size, seq_len, heads, dim // heads]` `heads` is
+ the number of heads initialized while constructing the `Attention` class.
+
+ Args:
+ tensor (`torch.Tensor`): The tensor to reshape.
+ out_dim (`int`, *optional*, defaults to `3`): The output dimension of the tensor. If `3`, the tensor is
+ reshaped to `[batch_size * heads, seq_len, dim // heads]`.
+
+ Returns:
+ `torch.Tensor`: The reshaped tensor.
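+
+        For example, with `heads=8` an input of shape `[2, 226, 512]` becomes `[16, 226, 64]`
+        when `out_dim=3`, or `[2, 8, 226, 64]` when `out_dim=4`.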
+ """
+ head_size = self.heads
+ if tensor.ndim == 3:
+ batch_size, seq_len, dim = tensor.shape
+ extra_dim = 1
+ else:
+ batch_size, extra_dim, seq_len, dim = tensor.shape
+ tensor = tensor.reshape(batch_size, seq_len * extra_dim, head_size, dim // head_size)
+ tensor = tensor.permute(0, 2, 1, 3)
+
+ if out_dim == 3:
+ tensor = tensor.reshape(batch_size * head_size, seq_len * extra_dim, dim // head_size)
+
+ return tensor
+
+ def get_attention_scores(
+ self, query: torch.Tensor, key: torch.Tensor, attention_mask: Optional[torch.Tensor] = None
+ ) -> torch.Tensor:
+ r"""
+ Compute the attention scores.
+
+ Args:
+ query (`torch.Tensor`): The query tensor.
+ key (`torch.Tensor`): The key tensor.
+ attention_mask (`torch.Tensor`, *optional*): The attention mask to use. If `None`, no mask is applied.
+
+ Returns:
+ `torch.Tensor`: The attention probabilities/scores.
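+
+        The pre-softmax scores are computed in a single `baddbmm` call as
+        `attention_mask + scale * query @ key.transpose(-1, -2)` (the mask term is dropped when
+        no mask is given), softmaxed along the last dimension, and cast back to the input dtype.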
+ """
+ dtype = query.dtype
+ if self.upcast_attention:
+ query = query.float()
+ key = key.float()
+
+ if attention_mask is None:
+ baddbmm_input = torch.empty(
+ query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device
+ )
+ beta = 0
+ else:
+ baddbmm_input = attention_mask
+ beta = 1
+
+ attention_scores = torch.baddbmm(
+ baddbmm_input,
+ query,
+ key.transpose(-1, -2),
+ beta=beta,
+ alpha=self.scale,
+ )
+ del baddbmm_input
+
+ if self.upcast_softmax:
+ attention_scores = attention_scores.float()
+
+ attention_probs = attention_scores.softmax(dim=-1)
+ del attention_scores
+
+ attention_probs = attention_probs.to(dtype)
+
+ return attention_probs
+
+ def prepare_attention_mask(
+ self, attention_mask: torch.Tensor, target_length: int, batch_size: int, out_dim: int = 3
+ ) -> torch.Tensor:
+ r"""
+ Prepare the attention mask for the attention computation.
+
+ Args:
+ attention_mask (`torch.Tensor`):
+ The attention mask to prepare.
+ target_length (`int`):
+ The target length of the attention mask. This is the length of the attention mask after padding.
+ batch_size (`int`):
+ The batch size, which is used to repeat the attention mask.
+ out_dim (`int`, *optional*, defaults to `3`):
+ The output dimension of the attention mask. Can be either `3` or `4`.
+
+ Returns:
+ `torch.Tensor`: The prepared attention mask.
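+
+        With `out_dim=3` the mask is repeated along the batch axis up to `batch_size * heads`
+        rows; with `out_dim=4` a head axis is inserted instead, giving `[batch_size, heads, ...]`.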
+ """
+ head_size = self.heads
+ if attention_mask is None:
+ return attention_mask
+
+ current_length: int = attention_mask.shape[-1]
+ if current_length != target_length:
+ if attention_mask.device.type == "mps":
+ # HACK: MPS: Does not support padding by greater than dimension of input tensor.
+ # Instead, we can manually construct the padding tensor.
+ padding_shape = (attention_mask.shape[0], attention_mask.shape[1], target_length)
+ padding = torch.zeros(padding_shape, dtype=attention_mask.dtype, device=attention_mask.device)
+ attention_mask = torch.cat([attention_mask, padding], dim=2)
+ else:
+ attention_mask = F.pad(attention_mask, (0, target_length), value=0.0)
+
+ if out_dim == 3:
+ if attention_mask.shape[0] < batch_size * head_size:
+ attention_mask = attention_mask.repeat_interleave(head_size, dim=0)
+ elif out_dim == 4:
+ attention_mask = attention_mask.unsqueeze(1)
+ attention_mask = attention_mask.repeat_interleave(head_size, dim=1)
+
+ return attention_mask
+
+ def norm_encoder_hidden_states(self, encoder_hidden_states: torch.Tensor) -> torch.Tensor:
+ r"""
+ Normalize the encoder hidden states. Requires `self.norm_cross` to be specified when constructing the
+ `Attention` class.
+
+ Args:
+ encoder_hidden_states (`torch.Tensor`): Hidden states of the encoder.
+
+ Returns:
+ `torch.Tensor`: The normalized encoder hidden states.
+ """
+ assert self.norm_cross is not None, "self.norm_cross must be defined to call self.norm_encoder_hidden_states"
+
+ if isinstance(self.norm_cross, nn.LayerNorm):
+ encoder_hidden_states = self.norm_cross(encoder_hidden_states)
+ elif isinstance(self.norm_cross, nn.GroupNorm):
+ # Group norm norms along the channels dimension and expects
+ # input to be in the shape of (N, C, *). In this case, we want
+ # to norm along the hidden dimension, so we need to move
+ # (batch_size, sequence_length, hidden_size) ->
+ # (batch_size, hidden_size, sequence_length)
+ encoder_hidden_states = encoder_hidden_states.transpose(1, 2)
+ encoder_hidden_states = self.norm_cross(encoder_hidden_states)
+ encoder_hidden_states = encoder_hidden_states.transpose(1, 2)
+ else:
+ assert False
+
+ return encoder_hidden_states
+
+ @torch.no_grad()
+ def fuse_projections(self, fuse=True):
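+        """
+        Fuse the separate q/k/v projection weights into a single `to_qkv` linear layer (or `to_kv`
+        for cross attention, plus `to_added_qkv` when added projections exist) so that fused
+        attention processors such as `FusedJointAttnProcessor2_0` can run one matmul per input
+        instead of three.
+        """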
+ device = self.to_q.weight.data.device
+ dtype = self.to_q.weight.data.dtype
+
+ if not self.is_cross_attention:
+ # fetch weight matrices.
+ concatenated_weights = torch.cat([self.to_q.weight.data, self.to_k.weight.data, self.to_v.weight.data])
+ in_features = concatenated_weights.shape[1]
+ out_features = concatenated_weights.shape[0]
+
+ # create a new single projection layer and copy over the weights.
+ self.to_qkv = nn.Linear(in_features, out_features, bias=self.use_bias, device=device, dtype=dtype)
+ self.to_qkv.weight.copy_(concatenated_weights)
+ if self.use_bias:
+ concatenated_bias = torch.cat([self.to_q.bias.data, self.to_k.bias.data, self.to_v.bias.data])
+ self.to_qkv.bias.copy_(concatenated_bias)
+
+ else:
+ concatenated_weights = torch.cat([self.to_k.weight.data, self.to_v.weight.data])
+ in_features = concatenated_weights.shape[1]
+ out_features = concatenated_weights.shape[0]
+
+ self.to_kv = nn.Linear(in_features, out_features, bias=self.use_bias, device=device, dtype=dtype)
+ self.to_kv.weight.copy_(concatenated_weights)
+ if self.use_bias:
+ concatenated_bias = torch.cat([self.to_k.bias.data, self.to_v.bias.data])
+ self.to_kv.bias.copy_(concatenated_bias)
+
+ # handle added projections for SD3 and others.
+ if hasattr(self, "add_q_proj") and hasattr(self, "add_k_proj") and hasattr(self, "add_v_proj"):
+ concatenated_weights = torch.cat(
+ [self.add_q_proj.weight.data, self.add_k_proj.weight.data, self.add_v_proj.weight.data]
+ )
+ in_features = concatenated_weights.shape[1]
+ out_features = concatenated_weights.shape[0]
+
+ self.to_added_qkv = nn.Linear(
+ in_features, out_features, bias=self.added_proj_bias, device=device, dtype=dtype
+ )
+ self.to_added_qkv.weight.copy_(concatenated_weights)
+ if self.added_proj_bias:
+ concatenated_bias = torch.cat(
+ [self.add_q_proj.bias.data, self.add_k_proj.bias.data, self.add_v_proj.bias.data]
+ )
+ self.to_added_qkv.bias.copy_(concatenated_bias)
+
+ self.fused_projections = fuse
+
+
+class AttnProcessor:
+ r"""
+ Default processor for performing attention-related computations.
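+
+    A processor is attached to an `Attention` module through `Attention.set_processor` and is
+    invoked from `Attention.forward`. This default implementation computes attention explicitly
+    with `baddbmm` + `softmax` (see `Attention.get_attention_scores`) followed by `torch.bmm`
+    with the value tensor, rather than calling `F.scaled_dot_product_attention`.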
+ """
+
+ def __call__(
+ self,
+ attn: Attention,
+ hidden_states: torch.Tensor,
+ encoder_hidden_states: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ temb: Optional[torch.Tensor] = None,
+ *args,
+ **kwargs,
+ ) -> torch.Tensor:
+ if len(args) > 0 or kwargs.get("scale", None) is not None:
+ deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
+ deprecate("scale", "1.0.0", deprecation_message)
+
+ residual = hidden_states
+
+ if attn.spatial_norm is not None:
+ hidden_states = attn.spatial_norm(hidden_states, temb)
+
+ input_ndim = hidden_states.ndim
+
+ if input_ndim == 4:
+ batch_size, channel, height, width = hidden_states.shape
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
+
+ batch_size, sequence_length, _ = (
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
+ )
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
+
+ if attn.group_norm is not None:
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
+
+ query = attn.to_q(hidden_states)
+
+ if encoder_hidden_states is None:
+ encoder_hidden_states = hidden_states
+ elif attn.norm_cross:
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
+
+ key = attn.to_k(encoder_hidden_states)
+ value = attn.to_v(encoder_hidden_states)
+
+ query = attn.head_to_batch_dim(query)
+ key = attn.head_to_batch_dim(key)
+ value = attn.head_to_batch_dim(value)
+
+ attention_probs = attn.get_attention_scores(query, key, attention_mask)
+ hidden_states = torch.bmm(attention_probs, value)
+ hidden_states = attn.batch_to_head_dim(hidden_states)
+
+ # linear proj
+ hidden_states = attn.to_out[0](hidden_states)
+ # dropout
+ hidden_states = attn.to_out[1](hidden_states)
+
+ if input_ndim == 4:
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
+
+ if attn.residual_connection:
+ hidden_states = hidden_states + residual
+
+ hidden_states = hidden_states / attn.rescale_output_factor
+
+ return hidden_states
+
+
+class CustomDiffusionAttnProcessor(nn.Module):
+ r"""
+ Processor for implementing attention for the Custom Diffusion method.
+
+ Args:
+ train_kv (`bool`, defaults to `True`):
+ Whether to newly train the key and value matrices corresponding to the text features.
+ train_q_out (`bool`, defaults to `True`):
+ Whether to newly train query matrices corresponding to the latent image features.
+ hidden_size (`int`, *optional*, defaults to `None`):
+ The hidden size of the attention layer.
+ cross_attention_dim (`int`, *optional*, defaults to `None`):
+ The number of channels in the `encoder_hidden_states`.
+ out_bias (`bool`, defaults to `True`):
+ Whether to include the bias parameter in `train_q_out`.
+ dropout (`float`, *optional*, defaults to 0.0):
+ The dropout probability to use.
+ """
+
+ def __init__(
+ self,
+ train_kv: bool = True,
+ train_q_out: bool = True,
+ hidden_size: Optional[int] = None,
+ cross_attention_dim: Optional[int] = None,
+ out_bias: bool = True,
+ dropout: float = 0.0,
+ ):
+ super().__init__()
+ self.train_kv = train_kv
+ self.train_q_out = train_q_out
+
+ self.hidden_size = hidden_size
+ self.cross_attention_dim = cross_attention_dim
+
+ # `_custom_diffusion` id for easy serialization and loading.
+ if self.train_kv:
+ self.to_k_custom_diffusion = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
+ self.to_v_custom_diffusion = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
+ if self.train_q_out:
+ self.to_q_custom_diffusion = nn.Linear(hidden_size, hidden_size, bias=False)
+ self.to_out_custom_diffusion = nn.ModuleList([])
+ self.to_out_custom_diffusion.append(nn.Linear(hidden_size, hidden_size, bias=out_bias))
+ self.to_out_custom_diffusion.append(nn.Dropout(dropout))
+
+ def __call__(
+ self,
+ attn: Attention,
+ hidden_states: torch.Tensor,
+ encoder_hidden_states: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ ) -> torch.Tensor:
+ batch_size, sequence_length, _ = hidden_states.shape
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
+ if self.train_q_out:
+ query = self.to_q_custom_diffusion(hidden_states).to(attn.to_q.weight.dtype)
+ else:
+ query = attn.to_q(hidden_states.to(attn.to_q.weight.dtype))
+
+ if encoder_hidden_states is None:
+ crossattn = False
+ encoder_hidden_states = hidden_states
+ else:
+ crossattn = True
+ if attn.norm_cross:
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
+
+ if self.train_kv:
+ key = self.to_k_custom_diffusion(encoder_hidden_states.to(self.to_k_custom_diffusion.weight.dtype))
+ value = self.to_v_custom_diffusion(encoder_hidden_states.to(self.to_v_custom_diffusion.weight.dtype))
+ key = key.to(attn.to_q.weight.dtype)
+ value = value.to(attn.to_q.weight.dtype)
+ else:
+ key = attn.to_k(encoder_hidden_states)
+ value = attn.to_v(encoder_hidden_states)
+
+ if crossattn:
+ detach = torch.ones_like(key)
+ detach[:, :1, :] = detach[:, :1, :] * 0.0
+ key = detach * key + (1 - detach) * key.detach()
+ value = detach * value + (1 - detach) * value.detach()
+
+ query = attn.head_to_batch_dim(query)
+ key = attn.head_to_batch_dim(key)
+ value = attn.head_to_batch_dim(value)
+
+ attention_probs = attn.get_attention_scores(query, key, attention_mask)
+ hidden_states = torch.bmm(attention_probs, value)
+ hidden_states = attn.batch_to_head_dim(hidden_states)
+
+ if self.train_q_out:
+ # linear proj
+ hidden_states = self.to_out_custom_diffusion[0](hidden_states)
+ # dropout
+ hidden_states = self.to_out_custom_diffusion[1](hidden_states)
+ else:
+ # linear proj
+ hidden_states = attn.to_out[0](hidden_states)
+ # dropout
+ hidden_states = attn.to_out[1](hidden_states)
+
+ return hidden_states
+
+
+class AttnAddedKVProcessor:
+ r"""
+ Processor for performing attention-related computations with extra learnable key and value matrices for the text
+ encoder.
+ """
+
+ def __call__(
+ self,
+ attn: Attention,
+ hidden_states: torch.Tensor,
+ encoder_hidden_states: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ *args,
+ **kwargs,
+ ) -> torch.Tensor:
+ if len(args) > 0 or kwargs.get("scale", None) is not None:
+ deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
+ deprecate("scale", "1.0.0", deprecation_message)
+
+ residual = hidden_states
+
+ hidden_states = hidden_states.view(hidden_states.shape[0], hidden_states.shape[1], -1).transpose(1, 2)
+ batch_size, sequence_length, _ = hidden_states.shape
+
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
+
+ if encoder_hidden_states is None:
+ encoder_hidden_states = hidden_states
+ elif attn.norm_cross:
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
+
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
+
+ query = attn.to_q(hidden_states)
+ query = attn.head_to_batch_dim(query)
+
+ encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
+ encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
+ encoder_hidden_states_key_proj = attn.head_to_batch_dim(encoder_hidden_states_key_proj)
+ encoder_hidden_states_value_proj = attn.head_to_batch_dim(encoder_hidden_states_value_proj)
+
+ if not attn.only_cross_attention:
+ key = attn.to_k(hidden_states)
+ value = attn.to_v(hidden_states)
+ key = attn.head_to_batch_dim(key)
+ value = attn.head_to_batch_dim(value)
+ key = torch.cat([encoder_hidden_states_key_proj, key], dim=1)
+ value = torch.cat([encoder_hidden_states_value_proj, value], dim=1)
+ else:
+ key = encoder_hidden_states_key_proj
+ value = encoder_hidden_states_value_proj
+
+ attention_probs = attn.get_attention_scores(query, key, attention_mask)
+ hidden_states = torch.bmm(attention_probs, value)
+ hidden_states = attn.batch_to_head_dim(hidden_states)
+
+ # linear proj
+ hidden_states = attn.to_out[0](hidden_states)
+ # dropout
+ hidden_states = attn.to_out[1](hidden_states)
+
+ hidden_states = hidden_states.transpose(-1, -2).reshape(residual.shape)
+ hidden_states = hidden_states + residual
+
+ return hidden_states
+
+
+class AttnAddedKVProcessor2_0:
+ r"""
+ Processor for performing scaled dot-product attention (enabled by default if you're using PyTorch 2.0), with extra
+ learnable key and value matrices for the text encoder.
+ """
+
+ def __init__(self):
+ if not hasattr(F, "scaled_dot_product_attention"):
+ raise ImportError(
+ "AttnAddedKVProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
+ )
+
+ def __call__(
+ self,
+ attn: Attention,
+ hidden_states: torch.Tensor,
+ encoder_hidden_states: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ *args,
+ **kwargs,
+ ) -> torch.Tensor:
+ if len(args) > 0 or kwargs.get("scale", None) is not None:
+ deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
+ deprecate("scale", "1.0.0", deprecation_message)
+
+ residual = hidden_states
+
+ hidden_states = hidden_states.view(hidden_states.shape[0], hidden_states.shape[1], -1).transpose(1, 2)
+ batch_size, sequence_length, _ = hidden_states.shape
+
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size, out_dim=4)
+
+ if encoder_hidden_states is None:
+ encoder_hidden_states = hidden_states
+ elif attn.norm_cross:
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
+
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
+
+ query = attn.to_q(hidden_states)
+ query = attn.head_to_batch_dim(query, out_dim=4)
+
+ encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
+ encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
+ encoder_hidden_states_key_proj = attn.head_to_batch_dim(encoder_hidden_states_key_proj, out_dim=4)
+ encoder_hidden_states_value_proj = attn.head_to_batch_dim(encoder_hidden_states_value_proj, out_dim=4)
+
+ if not attn.only_cross_attention:
+ key = attn.to_k(hidden_states)
+ value = attn.to_v(hidden_states)
+ key = attn.head_to_batch_dim(key, out_dim=4)
+ value = attn.head_to_batch_dim(value, out_dim=4)
+ key = torch.cat([encoder_hidden_states_key_proj, key], dim=2)
+ value = torch.cat([encoder_hidden_states_value_proj, value], dim=2)
+ else:
+ key = encoder_hidden_states_key_proj
+ value = encoder_hidden_states_value_proj
+
+ hidden_states = F.scaled_dot_product_attention(
+ query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
+ )
+ hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, residual.shape[1])
+
+ # linear proj
+ hidden_states = attn.to_out[0](hidden_states)
+ # dropout
+ hidden_states = attn.to_out[1](hidden_states)
+
+ hidden_states = hidden_states.transpose(-1, -2).reshape(residual.shape)
+ hidden_states = hidden_states + residual
+
+ return hidden_states
+
+
+class JointAttnProcessor2_0:
+ """Attention processor used typically in processing the SD3-like self-attention projections."""
+
+ def __init__(self):
+ if not hasattr(F, "scaled_dot_product_attention"):
+ raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
+
+ def __call__(
+ self,
+ attn: Attention,
+ hidden_states: torch.FloatTensor,
+ encoder_hidden_states: torch.FloatTensor = None,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ *args,
+ **kwargs,
+ ) -> torch.FloatTensor:
+ residual = hidden_states
+
+ batch_size = hidden_states.shape[0]
+
+ # `sample` projections.
+ query = attn.to_q(hidden_states)
+ key = attn.to_k(hidden_states)
+ value = attn.to_v(hidden_states)
+
+ inner_dim = key.shape[-1]
+ head_dim = inner_dim // attn.heads
+
+ query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+ key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+ value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+ if attn.norm_q is not None:
+ query = attn.norm_q(query)
+ if attn.norm_k is not None:
+ key = attn.norm_k(key)
+
+ # `context` projections.
+ if encoder_hidden_states is not None:
+ encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states)
+ encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
+ encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
+
+ encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view(
+ batch_size, -1, attn.heads, head_dim
+ ).transpose(1, 2)
+ encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(
+ batch_size, -1, attn.heads, head_dim
+ ).transpose(1, 2)
+ encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view(
+ batch_size, -1, attn.heads, head_dim
+ ).transpose(1, 2)
+
+ if attn.norm_added_q is not None:
+ encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj)
+ if attn.norm_added_k is not None:
+ encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj)
+
+ query = torch.cat([query, encoder_hidden_states_query_proj], dim=2)
+ key = torch.cat([key, encoder_hidden_states_key_proj], dim=2)
+ value = torch.cat([value, encoder_hidden_states_value_proj], dim=2)
+
+ hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)
+ hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
+ hidden_states = hidden_states.to(query.dtype)
+
+ if encoder_hidden_states is not None:
+ # Split the attention outputs.
+ hidden_states, encoder_hidden_states = (
+ hidden_states[:, : residual.shape[1]],
+ hidden_states[:, residual.shape[1] :],
+ )
+ if not attn.context_pre_only:
+ encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
+
+ # linear proj
+ hidden_states = attn.to_out[0](hidden_states)
+ # dropout
+ hidden_states = attn.to_out[1](hidden_states)
+
+ if encoder_hidden_states is not None:
+ return hidden_states, encoder_hidden_states
+ else:
+ return hidden_states
+
+
+class PAGJointAttnProcessor2_0:
+ """Attention processor used typically in processing the SD3-like self-attention projections."""
+
+ def __init__(self):
+ if not hasattr(F, "scaled_dot_product_attention"):
+ raise ImportError(
+ "PAGJointAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
+ )
+
+ def __call__(
+ self,
+ attn: Attention,
+ hidden_states: torch.FloatTensor,
+ encoder_hidden_states: torch.FloatTensor = None,
+ ) -> torch.FloatTensor:
+ residual = hidden_states
+
+ input_ndim = hidden_states.ndim
+ if input_ndim == 4:
+ batch_size, channel, height, width = hidden_states.shape
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
+ context_input_ndim = encoder_hidden_states.ndim
+ if context_input_ndim == 4:
+ batch_size, channel, height, width = encoder_hidden_states.shape
+ encoder_hidden_states = encoder_hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
+
+ # store the length of image patch sequences to create a mask that prevents interaction between patches
+ # similar to making the self-attention map an identity matrix
+ identity_block_size = hidden_states.shape[1]
+
+ # chunk
+ hidden_states_org, hidden_states_ptb = hidden_states.chunk(2)
+ encoder_hidden_states_org, encoder_hidden_states_ptb = encoder_hidden_states.chunk(2)
+
+ ################## original path ##################
+ batch_size = encoder_hidden_states_org.shape[0]
+
+ # `sample` projections.
+ query_org = attn.to_q(hidden_states_org)
+ key_org = attn.to_k(hidden_states_org)
+ value_org = attn.to_v(hidden_states_org)
+
+ # `context` projections.
+ encoder_hidden_states_org_query_proj = attn.add_q_proj(encoder_hidden_states_org)
+ encoder_hidden_states_org_key_proj = attn.add_k_proj(encoder_hidden_states_org)
+ encoder_hidden_states_org_value_proj = attn.add_v_proj(encoder_hidden_states_org)
+
+ # attention
+ query_org = torch.cat([query_org, encoder_hidden_states_org_query_proj], dim=1)
+ key_org = torch.cat([key_org, encoder_hidden_states_org_key_proj], dim=1)
+ value_org = torch.cat([value_org, encoder_hidden_states_org_value_proj], dim=1)
+
+ inner_dim = key_org.shape[-1]
+ head_dim = inner_dim // attn.heads
+ query_org = query_org.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+ key_org = key_org.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+ value_org = value_org.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+ hidden_states_org = F.scaled_dot_product_attention(
+ query_org, key_org, value_org, dropout_p=0.0, is_causal=False
+ )
+ hidden_states_org = hidden_states_org.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
+ hidden_states_org = hidden_states_org.to(query_org.dtype)
+
+ # Split the attention outputs.
+ hidden_states_org, encoder_hidden_states_org = (
+ hidden_states_org[:, : residual.shape[1]],
+ hidden_states_org[:, residual.shape[1] :],
+ )
+
+ # linear proj
+ hidden_states_org = attn.to_out[0](hidden_states_org)
+ # dropout
+ hidden_states_org = attn.to_out[1](hidden_states_org)
+ if not attn.context_pre_only:
+ encoder_hidden_states_org = attn.to_add_out(encoder_hidden_states_org)
+
+ if input_ndim == 4:
+ hidden_states_org = hidden_states_org.transpose(-1, -2).reshape(batch_size, channel, height, width)
+ if context_input_ndim == 4:
+ encoder_hidden_states_org = encoder_hidden_states_org.transpose(-1, -2).reshape(
+ batch_size, channel, height, width
+ )
+
+ ################## perturbed path ##################
+ batch_size = encoder_hidden_states_ptb.shape[0]
+
+ # `sample` projections.
+ query_ptb = attn.to_q(hidden_states_ptb)
+ key_ptb = attn.to_k(hidden_states_ptb)
+ value_ptb = attn.to_v(hidden_states_ptb)
+
+ # `context` projections.
+ encoder_hidden_states_ptb_query_proj = attn.add_q_proj(encoder_hidden_states_ptb)
+ encoder_hidden_states_ptb_key_proj = attn.add_k_proj(encoder_hidden_states_ptb)
+ encoder_hidden_states_ptb_value_proj = attn.add_v_proj(encoder_hidden_states_ptb)
+
+ # attention
+ query_ptb = torch.cat([query_ptb, encoder_hidden_states_ptb_query_proj], dim=1)
+ key_ptb = torch.cat([key_ptb, encoder_hidden_states_ptb_key_proj], dim=1)
+ value_ptb = torch.cat([value_ptb, encoder_hidden_states_ptb_value_proj], dim=1)
+
+ inner_dim = key_ptb.shape[-1]
+ head_dim = inner_dim // attn.heads
+ query_ptb = query_ptb.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+ key_ptb = key_ptb.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+ value_ptb = value_ptb.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+ # create a full mask with all entries set to 0
+ seq_len = query_ptb.size(2)
+ full_mask = torch.zeros((seq_len, seq_len), device=query_ptb.device, dtype=query_ptb.dtype)
+
+ # set the attention value between image patches to -inf
+ full_mask[:identity_block_size, :identity_block_size] = float("-inf")
+
+ # set the diagonal of the attention value between image patches to 0
+ full_mask[:identity_block_size, :identity_block_size].fill_diagonal_(0)
+
+ # expand the mask to match the attention weights shape
+ full_mask = full_mask.unsqueeze(0).unsqueeze(0) # Add batch and num_heads dimensions
+
+ hidden_states_ptb = F.scaled_dot_product_attention(
+ query_ptb, key_ptb, value_ptb, attn_mask=full_mask, dropout_p=0.0, is_causal=False
+ )
+ hidden_states_ptb = hidden_states_ptb.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
+ hidden_states_ptb = hidden_states_ptb.to(query_ptb.dtype)
+
+ # split the attention outputs.
+ hidden_states_ptb, encoder_hidden_states_ptb = (
+ hidden_states_ptb[:, : residual.shape[1]],
+ hidden_states_ptb[:, residual.shape[1] :],
+ )
+
+ # linear proj
+ hidden_states_ptb = attn.to_out[0](hidden_states_ptb)
+ # dropout
+ hidden_states_ptb = attn.to_out[1](hidden_states_ptb)
+ if not attn.context_pre_only:
+ encoder_hidden_states_ptb = attn.to_add_out(encoder_hidden_states_ptb)
+
+ if input_ndim == 4:
+ hidden_states_ptb = hidden_states_ptb.transpose(-1, -2).reshape(batch_size, channel, height, width)
+ if context_input_ndim == 4:
+ encoder_hidden_states_ptb = encoder_hidden_states_ptb.transpose(-1, -2).reshape(
+ batch_size, channel, height, width
+ )
+
+ ################ concat ###############
+ hidden_states = torch.cat([hidden_states_org, hidden_states_ptb])
+ encoder_hidden_states = torch.cat([encoder_hidden_states_org, encoder_hidden_states_ptb])
+
+ return hidden_states, encoder_hidden_states
+
+
+class PAGCFGJointAttnProcessor2_0:
+ """Attention processor used typically in processing the SD3-like self-attention projections."""
+
+ def __init__(self):
+ if not hasattr(F, "scaled_dot_product_attention"):
+ raise ImportError(
+ "PAGCFGJointAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
+ )
+
+ def __call__(
+ self,
+ attn: Attention,
+ hidden_states: torch.FloatTensor,
+ encoder_hidden_states: torch.FloatTensor = None,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ *args,
+ **kwargs,
+ ) -> torch.FloatTensor:
+ residual = hidden_states
+
+ input_ndim = hidden_states.ndim
+ if input_ndim == 4:
+ batch_size, channel, height, width = hidden_states.shape
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
+ context_input_ndim = encoder_hidden_states.ndim
+ if context_input_ndim == 4:
+ batch_size, channel, height, width = encoder_hidden_states.shape
+ encoder_hidden_states = encoder_hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
+
+        # patch embeddings width * height (corresponds to the self-attention map width or height)
+        identity_block_size = hidden_states.shape[1]
+
+ # chunk
+ hidden_states_uncond, hidden_states_org, hidden_states_ptb = hidden_states.chunk(3)
+ hidden_states_org = torch.cat([hidden_states_uncond, hidden_states_org])
+
+ (
+ encoder_hidden_states_uncond,
+ encoder_hidden_states_org,
+ encoder_hidden_states_ptb,
+ ) = encoder_hidden_states.chunk(3)
+ encoder_hidden_states_org = torch.cat([encoder_hidden_states_uncond, encoder_hidden_states_org])
+
+ ################## original path ##################
+ batch_size = encoder_hidden_states_org.shape[0]
+
+ # `sample` projections.
+ query_org = attn.to_q(hidden_states_org)
+ key_org = attn.to_k(hidden_states_org)
+ value_org = attn.to_v(hidden_states_org)
+
+ # `context` projections.
+ encoder_hidden_states_org_query_proj = attn.add_q_proj(encoder_hidden_states_org)
+ encoder_hidden_states_org_key_proj = attn.add_k_proj(encoder_hidden_states_org)
+ encoder_hidden_states_org_value_proj = attn.add_v_proj(encoder_hidden_states_org)
+
+ # attention
+ query_org = torch.cat([query_org, encoder_hidden_states_org_query_proj], dim=1)
+ key_org = torch.cat([key_org, encoder_hidden_states_org_key_proj], dim=1)
+ value_org = torch.cat([value_org, encoder_hidden_states_org_value_proj], dim=1)
+
+ inner_dim = key_org.shape[-1]
+ head_dim = inner_dim // attn.heads
+ query_org = query_org.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+ key_org = key_org.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+ value_org = value_org.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+ hidden_states_org = F.scaled_dot_product_attention(
+ query_org, key_org, value_org, dropout_p=0.0, is_causal=False
+ )
+ hidden_states_org = hidden_states_org.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
+ hidden_states_org = hidden_states_org.to(query_org.dtype)
+
+ # Split the attention outputs.
+ hidden_states_org, encoder_hidden_states_org = (
+ hidden_states_org[:, : residual.shape[1]],
+ hidden_states_org[:, residual.shape[1] :],
+ )
+
+ # linear proj
+ hidden_states_org = attn.to_out[0](hidden_states_org)
+ # dropout
+ hidden_states_org = attn.to_out[1](hidden_states_org)
+ if not attn.context_pre_only:
+ encoder_hidden_states_org = attn.to_add_out(encoder_hidden_states_org)
+
+ if input_ndim == 4:
+ hidden_states_org = hidden_states_org.transpose(-1, -2).reshape(batch_size, channel, height, width)
+ if context_input_ndim == 4:
+ encoder_hidden_states_org = encoder_hidden_states_org.transpose(-1, -2).reshape(
+ batch_size, channel, height, width
+ )
+
+ ################## perturbed path ##################
+ batch_size = encoder_hidden_states_ptb.shape[0]
+
+ # `sample` projections.
+ query_ptb = attn.to_q(hidden_states_ptb)
+ key_ptb = attn.to_k(hidden_states_ptb)
+ value_ptb = attn.to_v(hidden_states_ptb)
+
+ # `context` projections.
+ encoder_hidden_states_ptb_query_proj = attn.add_q_proj(encoder_hidden_states_ptb)
+ encoder_hidden_states_ptb_key_proj = attn.add_k_proj(encoder_hidden_states_ptb)
+ encoder_hidden_states_ptb_value_proj = attn.add_v_proj(encoder_hidden_states_ptb)
+
+ # attention
+ query_ptb = torch.cat([query_ptb, encoder_hidden_states_ptb_query_proj], dim=1)
+ key_ptb = torch.cat([key_ptb, encoder_hidden_states_ptb_key_proj], dim=1)
+ value_ptb = torch.cat([value_ptb, encoder_hidden_states_ptb_value_proj], dim=1)
+
+ inner_dim = key_ptb.shape[-1]
+ head_dim = inner_dim // attn.heads
+ query_ptb = query_ptb.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+ key_ptb = key_ptb.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+ value_ptb = value_ptb.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+ # create a full mask with all entries set to 0
+ seq_len = query_ptb.size(2)
+ full_mask = torch.zeros((seq_len, seq_len), device=query_ptb.device, dtype=query_ptb.dtype)
+
+ # set the attention value between image patches to -inf
+ full_mask[:identity_block_size, :identity_block_size] = float("-inf")
+
+ # set the diagonal of the attention value between image patches to 0
+ full_mask[:identity_block_size, :identity_block_size].fill_diagonal_(0)
+
+ # expand the mask to match the attention weights shape
+ full_mask = full_mask.unsqueeze(0).unsqueeze(0) # Add batch and num_heads dimensions
+
+ hidden_states_ptb = F.scaled_dot_product_attention(
+ query_ptb, key_ptb, value_ptb, attn_mask=full_mask, dropout_p=0.0, is_causal=False
+ )
+ hidden_states_ptb = hidden_states_ptb.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
+ hidden_states_ptb = hidden_states_ptb.to(query_ptb.dtype)
+
+ # split the attention outputs.
+ hidden_states_ptb, encoder_hidden_states_ptb = (
+ hidden_states_ptb[:, : residual.shape[1]],
+ hidden_states_ptb[:, residual.shape[1] :],
+ )
+
+ # linear proj
+ hidden_states_ptb = attn.to_out[0](hidden_states_ptb)
+ # dropout
+ hidden_states_ptb = attn.to_out[1](hidden_states_ptb)
+ if not attn.context_pre_only:
+ encoder_hidden_states_ptb = attn.to_add_out(encoder_hidden_states_ptb)
+
+ if input_ndim == 4:
+ hidden_states_ptb = hidden_states_ptb.transpose(-1, -2).reshape(batch_size, channel, height, width)
+ if context_input_ndim == 4:
+ encoder_hidden_states_ptb = encoder_hidden_states_ptb.transpose(-1, -2).reshape(
+ batch_size, channel, height, width
+ )
+
+ ################ concat ###############
+ hidden_states = torch.cat([hidden_states_org, hidden_states_ptb])
+ encoder_hidden_states = torch.cat([encoder_hidden_states_org, encoder_hidden_states_ptb])
+
+ return hidden_states, encoder_hidden_states
+
+
+class FusedJointAttnProcessor2_0:
+ """Attention processor used typically in processing the SD3-like self-attention projections."""
+
+ def __init__(self):
+ if not hasattr(F, "scaled_dot_product_attention"):
+ raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
+
+ def __call__(
+ self,
+ attn: Attention,
+ hidden_states: torch.FloatTensor,
+ encoder_hidden_states: torch.FloatTensor = None,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ *args,
+ **kwargs,
+ ) -> torch.FloatTensor:
+ residual = hidden_states
+
+ input_ndim = hidden_states.ndim
+ if input_ndim == 4:
+ batch_size, channel, height, width = hidden_states.shape
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
+ context_input_ndim = encoder_hidden_states.ndim
+ if context_input_ndim == 4:
+ batch_size, channel, height, width = encoder_hidden_states.shape
+ encoder_hidden_states = encoder_hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
+
+ batch_size = encoder_hidden_states.shape[0]
+
+ # `sample` projections.
+ qkv = attn.to_qkv(hidden_states)
+ split_size = qkv.shape[-1] // 3
+ query, key, value = torch.split(qkv, split_size, dim=-1)
+
+ # `context` projections.
+ encoder_qkv = attn.to_added_qkv(encoder_hidden_states)
+ split_size = encoder_qkv.shape[-1] // 3
+ (
+ encoder_hidden_states_query_proj,
+ encoder_hidden_states_key_proj,
+ encoder_hidden_states_value_proj,
+ ) = torch.split(encoder_qkv, split_size, dim=-1)
+
+ # attention
+ query = torch.cat([query, encoder_hidden_states_query_proj], dim=1)
+ key = torch.cat([key, encoder_hidden_states_key_proj], dim=1)
+ value = torch.cat([value, encoder_hidden_states_value_proj], dim=1)
+
+ inner_dim = key.shape[-1]
+ head_dim = inner_dim // attn.heads
+ query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+ key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+ value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+ hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)
+ hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
+ hidden_states = hidden_states.to(query.dtype)
+
+ # Split the attention outputs.
+ hidden_states, encoder_hidden_states = (
+ hidden_states[:, : residual.shape[1]],
+ hidden_states[:, residual.shape[1] :],
+ )
+
+ # linear proj
+ hidden_states = attn.to_out[0](hidden_states)
+ # dropout
+ hidden_states = attn.to_out[1](hidden_states)
+ if not attn.context_pre_only:
+ encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
+
+ if input_ndim == 4:
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
+ if context_input_ndim == 4:
+ encoder_hidden_states = encoder_hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
+
+ return hidden_states, encoder_hidden_states
+
+
+class AuraFlowAttnProcessor2_0:
+ """Attention processor used typically in processing Aura Flow."""
+
+ def __init__(self):
+ if not hasattr(F, "scaled_dot_product_attention") and is_torch_version("<", "2.1"):
+ raise ImportError(
+ "AuraFlowAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to at least 2.1 or above as we use `scale` in `F.scaled_dot_product_attention()`. "
+ )
+
+ def __call__(
+ self,
+ attn: Attention,
+ hidden_states: torch.FloatTensor,
+ encoder_hidden_states: torch.FloatTensor = None,
+ *args,
+ **kwargs,
+ ) -> torch.FloatTensor:
+ batch_size = hidden_states.shape[0]
+
+ # `sample` projections.
+ query = attn.to_q(hidden_states)
+ key = attn.to_k(hidden_states)
+ value = attn.to_v(hidden_states)
+
+ # `context` projections.
+ if encoder_hidden_states is not None:
+ encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states)
+ encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
+ encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
+
+ # Reshape.
+ inner_dim = key.shape[-1]
+ head_dim = inner_dim // attn.heads
+ query = query.view(batch_size, -1, attn.heads, head_dim)
+ key = key.view(batch_size, -1, attn.heads, head_dim)
+ value = value.view(batch_size, -1, attn.heads, head_dim)
+
+ # Apply QK norm.
+ if attn.norm_q is not None:
+ query = attn.norm_q(query)
+ if attn.norm_k is not None:
+ key = attn.norm_k(key)
+
+ # Concatenate the projections.
+ if encoder_hidden_states is not None:
+ encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view(
+ batch_size, -1, attn.heads, head_dim
+ )
+ encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(batch_size, -1, attn.heads, head_dim)
+ encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view(
+ batch_size, -1, attn.heads, head_dim
+ )
+
+ if attn.norm_added_q is not None:
+ encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj)
+ if attn.norm_added_k is not None:
+                encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj)
+
+ query = torch.cat([encoder_hidden_states_query_proj, query], dim=1)
+ key = torch.cat([encoder_hidden_states_key_proj, key], dim=1)
+ value = torch.cat([encoder_hidden_states_value_proj, value], dim=1)
+
+ query = query.transpose(1, 2)
+ key = key.transpose(1, 2)
+ value = value.transpose(1, 2)
+
+ # Attention.
+ hidden_states = F.scaled_dot_product_attention(
+ query, key, value, dropout_p=0.0, scale=attn.scale, is_causal=False
+ )
+ hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
+ hidden_states = hidden_states.to(query.dtype)
+
+ # Split the attention outputs.
+ if encoder_hidden_states is not None:
+ hidden_states, encoder_hidden_states = (
+ hidden_states[:, encoder_hidden_states.shape[1] :],
+ hidden_states[:, : encoder_hidden_states.shape[1]],
+ )
+
+ # linear proj
+ hidden_states = attn.to_out[0](hidden_states)
+ # dropout
+ hidden_states = attn.to_out[1](hidden_states)
+ if encoder_hidden_states is not None:
+ encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
+
+ if encoder_hidden_states is not None:
+ return hidden_states, encoder_hidden_states
+ else:
+ return hidden_states
+
+
+class FusedAuraFlowAttnProcessor2_0:
+ """Attention processor used typically in processing Aura Flow with fused projections."""
+
+ def __init__(self):
+ if not hasattr(F, "scaled_dot_product_attention") and is_torch_version("<", "2.1"):
+ raise ImportError(
+ "FusedAuraFlowAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to at least 2.1 or above as we use `scale` in `F.scaled_dot_product_attention()`. "
+ )
+
+ def __call__(
+ self,
+ attn: Attention,
+ hidden_states: torch.FloatTensor,
+ encoder_hidden_states: torch.FloatTensor = None,
+ *args,
+ **kwargs,
+ ) -> torch.FloatTensor:
+ batch_size = hidden_states.shape[0]
+
+ # `sample` projections.
+ qkv = attn.to_qkv(hidden_states)
+ split_size = qkv.shape[-1] // 3
+ query, key, value = torch.split(qkv, split_size, dim=-1)
+
+ # `context` projections.
+ if encoder_hidden_states is not None:
+ encoder_qkv = attn.to_added_qkv(encoder_hidden_states)
+ split_size = encoder_qkv.shape[-1] // 3
+ (
+ encoder_hidden_states_query_proj,
+ encoder_hidden_states_key_proj,
+ encoder_hidden_states_value_proj,
+ ) = torch.split(encoder_qkv, split_size, dim=-1)
+
+ # Reshape.
+ inner_dim = key.shape[-1]
+ head_dim = inner_dim // attn.heads
+ query = query.view(batch_size, -1, attn.heads, head_dim)
+ key = key.view(batch_size, -1, attn.heads, head_dim)
+ value = value.view(batch_size, -1, attn.heads, head_dim)
+
+ # Apply QK norm.
+ if attn.norm_q is not None:
+ query = attn.norm_q(query)
+ if attn.norm_k is not None:
+ key = attn.norm_k(key)
+
+ # Concatenate the projections.
+ if encoder_hidden_states is not None:
+ encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view(
+ batch_size, -1, attn.heads, head_dim
+ )
+ encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(batch_size, -1, attn.heads, head_dim)
+ encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view(
+ batch_size, -1, attn.heads, head_dim
+ )
+
+ if attn.norm_added_q is not None:
+ encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj)
+ if attn.norm_added_k is not None:
+                encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj)
+
+ query = torch.cat([encoder_hidden_states_query_proj, query], dim=1)
+ key = torch.cat([encoder_hidden_states_key_proj, key], dim=1)
+ value = torch.cat([encoder_hidden_states_value_proj, value], dim=1)
+
+ query = query.transpose(1, 2)
+ key = key.transpose(1, 2)
+ value = value.transpose(1, 2)
+
+ # Attention.
+ hidden_states = F.scaled_dot_product_attention(
+ query, key, value, dropout_p=0.0, scale=attn.scale, is_causal=False
+ )
+ hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
+ hidden_states = hidden_states.to(query.dtype)
+
+ # Split the attention outputs.
+ if encoder_hidden_states is not None:
+ hidden_states, encoder_hidden_states = (
+ hidden_states[:, encoder_hidden_states.shape[1] :],
+ hidden_states[:, : encoder_hidden_states.shape[1]],
+ )
+
+ # linear proj
+ hidden_states = attn.to_out[0](hidden_states)
+ # dropout
+ hidden_states = attn.to_out[1](hidden_states)
+ if encoder_hidden_states is not None:
+ encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
+
+ if encoder_hidden_states is not None:
+ return hidden_states, encoder_hidden_states
+ else:
+ return hidden_states
+
+
+class FluxAttnProcessor2_0:
+ """Attention processor used typically in processing the SD3-like self-attention projections."""
+
+ def __init__(self):
+ if not hasattr(F, "scaled_dot_product_attention"):
+ raise ImportError("FluxAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
+
+ def __call__(
+ self,
+ attn: Attention,
+ hidden_states: torch.FloatTensor,
+ encoder_hidden_states: torch.FloatTensor = None,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ image_rotary_emb: Optional[torch.Tensor] = None,
+ ) -> torch.FloatTensor:
+ batch_size, _, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
+
+ # `sample` projections.
+ query = attn.to_q(hidden_states)
+ key = attn.to_k(hidden_states)
+ value = attn.to_v(hidden_states)
+
+ inner_dim = key.shape[-1]
+ head_dim = inner_dim // attn.heads
+
+ query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+ key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+ value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+ if attn.norm_q is not None:
+ query = attn.norm_q(query)
+ if attn.norm_k is not None:
+ key = attn.norm_k(key)
+
+ # the attention in FluxSingleTransformerBlock does not use `encoder_hidden_states`
+ if encoder_hidden_states is not None:
+ # `context` projections.
+ encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states)
+ encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
+ encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
+
+ encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view(
+ batch_size, -1, attn.heads, head_dim
+ ).transpose(1, 2)
+ encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(
+ batch_size, -1, attn.heads, head_dim
+ ).transpose(1, 2)
+ encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view(
+ batch_size, -1, attn.heads, head_dim
+ ).transpose(1, 2)
+
+ if attn.norm_added_q is not None:
+ encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj)
+ if attn.norm_added_k is not None:
+ encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj)
+
+ # attention
+ query = torch.cat([encoder_hidden_states_query_proj, query], dim=2)
+ key = torch.cat([encoder_hidden_states_key_proj, key], dim=2)
+ value = torch.cat([encoder_hidden_states_value_proj, value], dim=2)
+
+ if image_rotary_emb is not None:
+ from .embeddings import apply_rotary_emb
+
+ query = apply_rotary_emb(query, image_rotary_emb)
+ key = apply_rotary_emb(key, image_rotary_emb)
+
+ hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)
+ hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
+ hidden_states = hidden_states.to(query.dtype)
+
+ if encoder_hidden_states is not None:
+ encoder_hidden_states, hidden_states = (
+ hidden_states[:, : encoder_hidden_states.shape[1]],
+ hidden_states[:, encoder_hidden_states.shape[1] :],
+ )
+
+ # linear proj
+ hidden_states = attn.to_out[0](hidden_states)
+ # dropout
+ hidden_states = attn.to_out[1](hidden_states)
+ encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
+
+ return hidden_states, encoder_hidden_states
+ else:
+ return hidden_states
+
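+# --- Illustrative sketch (not part of the original file) ---------------------
+# FluxAttnProcessor2_0 concatenates the text-stream and image-stream Q/K along
+# the sequence axis (dim=2, after the head split) and then rotates the joint
+# Q/K with `apply_rotary_emb`. The core rotation of a rotary embedding, written
+# out for one common channel-pairing convention (the exact layout used by
+# `apply_rotary_emb` may differ), with hypothetical sizes:
+def _example_rotary_rotation():
+    import torch
+
+    x = torch.randn(1, 1, 4, 8)                      # [B, H, S, D], D even
+    angles = torch.rand(4, 4)                        # one angle per (position, D // 2)
+    cos, sin = angles.cos(), angles.sin()
+
+    x1, x2 = x[..., 0::2], x[..., 1::2]              # pair up the channels
+    rotated = torch.stack((x1 * cos - x2 * sin,
+                           x1 * sin + x2 * cos), dim=-1).flatten(-2)
+    return rotated.shape                             # same shape as x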
+
+class FusedFluxAttnProcessor2_0:
+ """Attention processor used typically in processing the SD3-like self-attention projections."""
+
+ def __init__(self):
+ if not hasattr(F, "scaled_dot_product_attention"):
+ raise ImportError(
+ "FusedFluxAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
+ )
+
+ def __call__(
+ self,
+ attn: Attention,
+ hidden_states: torch.FloatTensor,
+ encoder_hidden_states: torch.FloatTensor = None,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ image_rotary_emb: Optional[torch.Tensor] = None,
+ ) -> torch.FloatTensor:
+ batch_size, _, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
+
+ # `sample` projections.
+ qkv = attn.to_qkv(hidden_states)
+ split_size = qkv.shape[-1] // 3
+ query, key, value = torch.split(qkv, split_size, dim=-1)
+
+ inner_dim = key.shape[-1]
+ head_dim = inner_dim // attn.heads
+
+ query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+ key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+ value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+ if attn.norm_q is not None:
+ query = attn.norm_q(query)
+ if attn.norm_k is not None:
+ key = attn.norm_k(key)
+
+ # the attention in FluxSingleTransformerBlock does not use `encoder_hidden_states`
+ # `context` projections.
+ if encoder_hidden_states is not None:
+ encoder_qkv = attn.to_added_qkv(encoder_hidden_states)
+ split_size = encoder_qkv.shape[-1] // 3
+ (
+ encoder_hidden_states_query_proj,
+ encoder_hidden_states_key_proj,
+ encoder_hidden_states_value_proj,
+ ) = torch.split(encoder_qkv, split_size, dim=-1)
+
+ encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view(
+ batch_size, -1, attn.heads, head_dim
+ ).transpose(1, 2)
+ encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(
+ batch_size, -1, attn.heads, head_dim
+ ).transpose(1, 2)
+ encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view(
+ batch_size, -1, attn.heads, head_dim
+ ).transpose(1, 2)
+
+ if attn.norm_added_q is not None:
+ encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj)
+ if attn.norm_added_k is not None:
+ encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj)
+
+ # attention
+ query = torch.cat([encoder_hidden_states_query_proj, query], dim=2)
+ key = torch.cat([encoder_hidden_states_key_proj, key], dim=2)
+ value = torch.cat([encoder_hidden_states_value_proj, value], dim=2)
+
+ if image_rotary_emb is not None:
+ from .embeddings import apply_rotary_emb
+
+ query = apply_rotary_emb(query, image_rotary_emb)
+ key = apply_rotary_emb(key, image_rotary_emb)
+
+ hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)
+ hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
+ hidden_states = hidden_states.to(query.dtype)
+
+ if encoder_hidden_states is not None:
+ encoder_hidden_states, hidden_states = (
+ hidden_states[:, : encoder_hidden_states.shape[1]],
+ hidden_states[:, encoder_hidden_states.shape[1] :],
+ )
+
+ # linear proj
+ hidden_states = attn.to_out[0](hidden_states)
+ # dropout
+ hidden_states = attn.to_out[1](hidden_states)
+ encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
+
+ return hidden_states, encoder_hidden_states
+ else:
+ return hidden_states
+
+
+class CogVideoXAttnProcessor2_0:
+ r"""
+ Processor for implementing scaled dot-product attention for the CogVideoX model. It applies a rotary embedding on
+ query and key vectors, but does not include spatial normalization.
+ """
+
+ def __init__(self):
+ if not hasattr(F, "scaled_dot_product_attention"):
+ raise ImportError("CogVideoXAttnProcessor requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
+
+ def __call__(
+ self,
+ attn: Attention,
+ hidden_states: torch.Tensor,
+ encoder_hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ image_rotary_emb: Optional[torch.Tensor] = None,
+ ) -> torch.Tensor:
+ text_seq_length = encoder_hidden_states.size(1)
+ latent_seq_length = hidden_states.size(1)
+
+ hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
+
+ batch_size, sequence_length, _ = (
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
+ )
+
+ if attention_mask is not None:
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
+ attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
+
+ query = attn.to_q(hidden_states)
+ key = attn.to_k(hidden_states)
+ value = attn.to_v(hidden_states)
+
+ inner_dim = key.shape[-1]
+ head_dim = inner_dim // attn.heads
+
+ query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+ key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+ value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+ if attn.norm_q is not None:
+ query = attn.norm_q(query)
+ if attn.norm_k is not None:
+ key = attn.norm_k(key)
+
+ # Apply RoPE if needed
+ if image_rotary_emb is not None:
+ from .embeddings import apply_rotary_emb
+
+ query[:, :, text_seq_length:] = apply_rotary_emb(query[:, :, text_seq_length:], image_rotary_emb)
+ if not attn.is_cross_attention:
+ key[:, :, text_seq_length:] = apply_rotary_emb(key[:, :, text_seq_length:], image_rotary_emb)
+
+ if get_world_size() > 1:
+            hidden_states = gather_parrellel_ga(query, key, value, 1.0 / math.sqrt(query.shape[-1]), get_world_size())
+ else:
+ hidden_states = torch_npu.npu_prompt_flash_attention(
+ query, key, value, num_heads=attn.heads,
+ input_layout='BNSD',
+ scale_value=1.0 / math.sqrt(query.shape[-1]),
+ atten_mask=attention_mask,
+ pre_tokens=MAX_TOKENS,
+ next_tokens=MAX_TOKENS,
+ sparse_mode=0
+ )
+
+ hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
+
+ # linear proj
+ hidden_states = attn.to_out[0](hidden_states)
+ # dropout
+ hidden_states = attn.to_out[1](hidden_states)
+
+ encoder_hidden_states, hidden_states = hidden_states.split(
+ [text_seq_length, latent_seq_length], dim=1
+ )
+ return hidden_states, encoder_hidden_states
+
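+# --- Illustrative sketch (not part of the original file) ---------------------
+# CogVideoXAttnProcessor2_0 prepends the text tokens to the latent (video)
+# tokens, applies RoPE only to the latent part of Q/K, and finally splits the
+# projected output back into the two streams. A shape walk-through with
+# hypothetical sizes:
+def _example_cogvideox_token_layout():
+    import torch
+
+    batch, text_len, latent_len, dim = 1, 226, 1350, 64   # sizes for illustration only
+    encoder_hidden_states = torch.randn(batch, text_len, dim)
+    hidden_states = torch.randn(batch, latent_len, dim)
+
+    joint = torch.cat([encoder_hidden_states, hidden_states], dim=1)
+    # joint[:, :text_len]  -> text tokens (no RoPE)
+    # joint[:, text_len:]  -> latent tokens (RoPE is applied to this slice)
+    text_part, latent_part = joint.split([text_len, latent_len], dim=1)
+    return text_part.shape, latent_part.shape
+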
+def gather_parrellel_ga(
+    q, k, v,
+ scale_value,
+ world_size,
+ num_head_split=8,
+):
+ """
+ All Gather key-value pairs in parallel for Flash attention .
+
+
+ Args:
+ qkv_list (List[torch.Tensor]): A list containing query (q), key (k), and value (v) tensors.
+ the key and value should in the shape [B N S D]
+ head_dim (int): The dimension of each attention head.
+ world_size (int): The number of distributed processes.
+ num_head_split (int, optional): The number of splits for the attention heads. Defaults to 8.
+
+ Returns:
+ torch.Tensor: The output tensor after applying parallel attention.
+ The shape [B N S D]
+ """
+ q_list = q.chunk(num_head_split, dim=1)
+
+ kv = torch.cat((k, v), dim=0)
+ kv_list = kv.chunk(num_head_split, dim=1)
+ kv_split = kv_list[0].contiguous()
+ b, n, s, d = kv_split.shape
+ kv_full = torch.empty([world_size, b, n, s, d], dtype=kv_split.dtype, device=kv_split.device)
+ torch.distributed.all_gather_into_tensor(kv_full, kv_split)
+ kv_full = kv_full.permute(1, 2, 0, 3, 4).reshape(b, n, -1, d)
+
+ out = []
+ for step in range(num_head_split):
+ k, v = kv_full.chunk(2, dim=0)
+ if step != num_head_split - 1:
+ kv_split = kv_list[step + 1].contiguous()
+ b, n, s, d = kv_split.shape
+ kv_full = torch.empty([world_size, b, n, s, d], dtype=kv_split.dtype, device=kv_split.device)
+ req = torch.distributed.all_gather_into_tensor(kv_full, kv_split, async_op=True)
+
+ output = torch_npu.npu_prompt_flash_attention(
+ q_list[step], k, v,
+ num_heads=k.shape[1],
+ input_layout="BNSD",
+ scale_value=scale_value,
+ pre_tokens=MAX_TOKENS,
+ next_tokens=MAX_TOKENS
+ )
+
+ out.append(output)
+
+ if step != num_head_split - 1:
+ req.wait()
+ kv_full = kv_full.permute(1, 2, 0, 3, 4).reshape(b, n, -1, d)
+ out = torch.cat(out, dim=1)
+ return out
+
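+# --- Illustrative sketch (not part of the original file) ---------------------
+# `gather_parrellel_ga` splits the attention heads into `num_head_split` groups
+# and, for each group, all-gathers the K/V sequence slices held by every rank
+# before calling the flash-attention kernel, so the next all-gather overlaps
+# with the current computation. The shape bookkeeping, reproduced on a single
+# process with hypothetical sizes (no distributed collective is issued here):
+def _example_allgather_kv_shapes(world_size=2):
+    import torch
+
+    b, n, s_local, d = 1, 4, 8, 16                     # local K/V slice: [B, N, S_local, D]
+    kv_split = torch.randn(b, n, s_local, d)
+
+    # all_gather_into_tensor stacks one slice per rank along a new leading dim.
+    kv_full = torch.stack([kv_split] * world_size, dim=0)          # [world, B, N, S_local, D]
+
+    # Re-order so the gathered slices form one contiguous sequence axis.
+    kv_full = kv_full.permute(1, 2, 0, 3, 4).reshape(b, n, -1, d)  # [B, N, world * S_local, D]
+    return kv_full.shape
+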
+class FusedCogVideoXAttnProcessor2_0:
+ r"""
+ Processor for implementing scaled dot-product attention for the CogVideoX model. It applies a rotary embedding on
+ query and key vectors, but does not include spatial normalization.
+ """
+
+ def __init__(self):
+ if not hasattr(F, "scaled_dot_product_attention"):
+ raise ImportError("CogVideoXAttnProcessor requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
+
+ def __call__(
+ self,
+ attn: Attention,
+ hidden_states: torch.Tensor,
+ encoder_hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ image_rotary_emb: Optional[torch.Tensor] = None,
+ ) -> torch.Tensor:
+ text_seq_length = encoder_hidden_states.size(1)
+ latent_seq_length = hidden_states.size(1)
+
+ hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
+
+ batch_size, sequence_length, _ = (
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
+ )
+
+ if attention_mask is not None:
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
+ attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
+
+ qkv = attn.to_qkv(hidden_states)
+ split_size = qkv.shape[-1] // 3
+ query, key, value = torch.split(qkv, split_size, dim=-1)
+
+ inner_dim = key.shape[-1]
+ head_dim = inner_dim // attn.heads
+
+ query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+ key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+ value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+ if attn.norm_q is not None:
+ query = attn.norm_q(query)
+ if attn.norm_k is not None:
+ key = attn.norm_k(key)
+
+ # Apply RoPE if needed
+ if image_rotary_emb is not None:
+ from .embeddings import apply_rotary_emb
+
+ query[:, :, text_seq_length:] = apply_rotary_emb(query[:, :, text_seq_length:], image_rotary_emb)
+ if not attn.is_cross_attention:
+ key[:, :, text_seq_length:] = apply_rotary_emb(key[:, :, text_seq_length:], image_rotary_emb)
+
+ hidden_states = F.scaled_dot_product_attention(
+ query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
+ )
+
+ hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
+
+ # linear proj
+ hidden_states = attn.to_out[0](hidden_states)
+ # dropout
+ hidden_states = attn.to_out[1](hidden_states)
+
+ encoder_hidden_states, hidden_states = hidden_states.split(
+ [text_seq_length, hidden_states.size(1) - text_seq_length], dim=1
+ )
+ return hidden_states, encoder_hidden_states
+
+
+class XFormersAttnAddedKVProcessor:
+ r"""
+ Processor for implementing memory efficient attention using xFormers.
+
+ Args:
+ attention_op (`Callable`, *optional*, defaults to `None`):
+ The base
+ [operator](https://facebookresearch.github.io/xformers/components/ops.html#xformers.ops.AttentionOpBase) to
+ use as the attention operator. It is recommended to set to `None`, and allow xFormers to choose the best
+ operator.
+ """
+
+ def __init__(self, attention_op: Optional[Callable] = None):
+ self.attention_op = attention_op
+
+ def __call__(
+ self,
+ attn: Attention,
+ hidden_states: torch.Tensor,
+ encoder_hidden_states: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ ) -> torch.Tensor:
+ residual = hidden_states
+ hidden_states = hidden_states.view(hidden_states.shape[0], hidden_states.shape[1], -1).transpose(1, 2)
+ batch_size, sequence_length, _ = hidden_states.shape
+
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
+
+ if encoder_hidden_states is None:
+ encoder_hidden_states = hidden_states
+ elif attn.norm_cross:
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
+
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
+
+ query = attn.to_q(hidden_states)
+ query = attn.head_to_batch_dim(query)
+
+ encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
+ encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
+ encoder_hidden_states_key_proj = attn.head_to_batch_dim(encoder_hidden_states_key_proj)
+ encoder_hidden_states_value_proj = attn.head_to_batch_dim(encoder_hidden_states_value_proj)
+
+ if not attn.only_cross_attention:
+ key = attn.to_k(hidden_states)
+ value = attn.to_v(hidden_states)
+ key = attn.head_to_batch_dim(key)
+ value = attn.head_to_batch_dim(value)
+ key = torch.cat([encoder_hidden_states_key_proj, key], dim=1)
+ value = torch.cat([encoder_hidden_states_value_proj, value], dim=1)
+ else:
+ key = encoder_hidden_states_key_proj
+ value = encoder_hidden_states_value_proj
+
+ hidden_states = xformers.ops.memory_efficient_attention(
+ query, key, value, attn_bias=attention_mask, op=self.attention_op, scale=attn.scale
+ )
+ hidden_states = hidden_states.to(query.dtype)
+ hidden_states = attn.batch_to_head_dim(hidden_states)
+
+ # linear proj
+ hidden_states = attn.to_out[0](hidden_states)
+ # dropout
+ hidden_states = attn.to_out[1](hidden_states)
+
+ hidden_states = hidden_states.transpose(-1, -2).reshape(residual.shape)
+ hidden_states = hidden_states + residual
+
+ return hidden_states
+
+
+class XFormersAttnProcessor:
+ r"""
+ Processor for implementing memory efficient attention using xFormers.
+
+ Args:
+ attention_op (`Callable`, *optional*, defaults to `None`):
+ The base
+ [operator](https://facebookresearch.github.io/xformers/components/ops.html#xformers.ops.AttentionOpBase) to
+ use as the attention operator. It is recommended to set to `None`, and allow xFormers to choose the best
+ operator.
+ """
+
+ def __init__(self, attention_op: Optional[Callable] = None):
+ self.attention_op = attention_op
+
+ def __call__(
+ self,
+ attn: Attention,
+ hidden_states: torch.Tensor,
+ encoder_hidden_states: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ temb: Optional[torch.Tensor] = None,
+ *args,
+ **kwargs,
+ ) -> torch.Tensor:
+ if len(args) > 0 or kwargs.get("scale", None) is not None:
+ deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
+ deprecate("scale", "1.0.0", deprecation_message)
+
+ residual = hidden_states
+
+ if attn.spatial_norm is not None:
+ hidden_states = attn.spatial_norm(hidden_states, temb)
+
+ input_ndim = hidden_states.ndim
+
+ if input_ndim == 4:
+ batch_size, channel, height, width = hidden_states.shape
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
+
+ batch_size, key_tokens, _ = (
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
+ )
+
+ attention_mask = attn.prepare_attention_mask(attention_mask, key_tokens, batch_size)
+ if attention_mask is not None:
+ _, query_tokens, _ = hidden_states.shape
+ attention_mask = attention_mask.expand(-1, query_tokens, -1)
+
+ if attn.group_norm is not None:
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
+
+ query = attn.to_q(hidden_states)
+
+ if encoder_hidden_states is None:
+ encoder_hidden_states = hidden_states
+ elif attn.norm_cross:
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
+
+ key = attn.to_k(encoder_hidden_states)
+ value = attn.to_v(encoder_hidden_states)
+
+ query = attn.head_to_batch_dim(query).contiguous()
+ key = attn.head_to_batch_dim(key).contiguous()
+ value = attn.head_to_batch_dim(value).contiguous()
+
+ hidden_states = xformers.ops.memory_efficient_attention(
+ query, key, value, attn_bias=attention_mask, op=self.attention_op, scale=attn.scale
+ )
+ hidden_states = hidden_states.to(query.dtype)
+ hidden_states = attn.batch_to_head_dim(hidden_states)
+
+ # linear proj
+ hidden_states = attn.to_out[0](hidden_states)
+ # dropout
+ hidden_states = attn.to_out[1](hidden_states)
+
+ if input_ndim == 4:
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
+
+ if attn.residual_connection:
+ hidden_states = hidden_states + residual
+
+ hidden_states = hidden_states / attn.rescale_output_factor
+
+ return hidden_states
+
+
+class AttnProcessorNPU:
+ r"""
+ Processor for implementing flash attention using torch_npu. Torch_npu supports only fp16 and bf16 data types. If
+ fp32 is used, F.scaled_dot_product_attention will be used for computation, but the acceleration effect on NPU is
+ not significant.
+
+ """
+
+ def __init__(self):
+ if not is_torch_npu_available():
+ raise ImportError("AttnProcessorNPU requires torch_npu extensions and is supported only on npu devices.")
+
+ def __call__(
+ self,
+ attn: Attention,
+ hidden_states: torch.Tensor,
+ encoder_hidden_states: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ temb: Optional[torch.Tensor] = None,
+ *args,
+ **kwargs,
+ ) -> torch.Tensor:
+ if len(args) > 0 or kwargs.get("scale", None) is not None:
+ deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
+ deprecate("scale", "1.0.0", deprecation_message)
+
+ residual = hidden_states
+ if attn.spatial_norm is not None:
+ hidden_states = attn.spatial_norm(hidden_states, temb)
+
+ input_ndim = hidden_states.ndim
+
+ if input_ndim == 4:
+ batch_size, channel, height, width = hidden_states.shape
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
+
+ batch_size, sequence_length, _ = (
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
+ )
+
+ if attention_mask is not None:
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
+ attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
+
+ if attn.group_norm is not None:
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
+
+ query = attn.to_q(hidden_states)
+
+ if encoder_hidden_states is None:
+ encoder_hidden_states = hidden_states
+ elif attn.norm_cross:
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
+
+ key = attn.to_k(encoder_hidden_states)
+ value = attn.to_v(encoder_hidden_states)
+
+ inner_dim = key.shape[-1]
+ head_dim = inner_dim // attn.heads
+
+ query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+ key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+ value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+ if query.dtype in (torch.float16, torch.bfloat16):
+ hidden_states = torch_npu.npu_fusion_attention(
+ query,
+ key,
+ value,
+ attn.heads,
+ input_layout="BNSD",
+ pse=None,
+ atten_mask=attention_mask,
+ scale=1.0 / math.sqrt(query.shape[-1]),
+ pre_tockens=65536,
+ next_tockens=65536,
+ keep_prob=1.0,
+ sync=False,
+ inner_precise=0,
+ )[0]
+ else:
+ hidden_states = F.scaled_dot_product_attention(
+ query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
+ )
+
+ hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
+ hidden_states = hidden_states.to(query.dtype)
+
+ # linear proj
+ hidden_states = attn.to_out[0](hidden_states)
+ # dropout
+ hidden_states = attn.to_out[1](hidden_states)
+
+ if input_ndim == 4:
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
+
+ if attn.residual_connection:
+ hidden_states = hidden_states + residual
+
+ hidden_states = hidden_states / attn.rescale_output_factor
+
+ return hidden_states
+
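+# --- Illustrative sketch (not part of the original file) ---------------------
+# AttnProcessorNPU only reaches the fused torch_npu kernel for fp16/bf16
+# queries; fp32 falls back to F.scaled_dot_product_attention. Casting the model
+# to half precision before inference is therefore what enables the fast path.
+# The dispatch rule itself, isolated for illustration:
+def _example_npu_dtype_dispatch(query_dtype):
+    import torch
+
+    # Mirrors the check above: the fused kernel is used only for half precision.
+    if query_dtype in (torch.float16, torch.bfloat16):
+        return "torch_npu.npu_fusion_attention"
+    return "F.scaled_dot_product_attention"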
+
+class AttnProcessor2_0:
+ r"""
+ Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
+ """
+
+ def __init__(self):
+ if not hasattr(F, "scaled_dot_product_attention"):
+ raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
+
+ def __call__(
+ self,
+ attn: Attention,
+ hidden_states: torch.Tensor,
+ encoder_hidden_states: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ temb: Optional[torch.Tensor] = None,
+ *args,
+ **kwargs,
+ ) -> torch.Tensor:
+ if len(args) > 0 or kwargs.get("scale", None) is not None:
+ deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
+ deprecate("scale", "1.0.0", deprecation_message)
+
+ residual = hidden_states
+ if attn.spatial_norm is not None:
+ hidden_states = attn.spatial_norm(hidden_states, temb)
+
+ input_ndim = hidden_states.ndim
+
+ if input_ndim == 4:
+ batch_size, channel, height, width = hidden_states.shape
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
+
+ batch_size, sequence_length, _ = (
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
+ )
+
+ if attention_mask is not None:
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
+ attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
+
+ if attn.group_norm is not None:
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
+
+ query = attn.to_q(hidden_states)
+
+ if encoder_hidden_states is None:
+ encoder_hidden_states = hidden_states
+ elif attn.norm_cross:
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
+
+ key = attn.to_k(encoder_hidden_states)
+ value = attn.to_v(encoder_hidden_states)
+
+ inner_dim = key.shape[-1]
+ head_dim = inner_dim // attn.heads
+
+ query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+ key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+ value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+ if attn.norm_q is not None:
+ query = attn.norm_q(query)
+ if attn.norm_k is not None:
+ key = attn.norm_k(key)
+
+ hidden_states = F.scaled_dot_product_attention(
+ query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
+ )
+
+ hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
+ hidden_states = hidden_states.to(query.dtype)
+
+ # linear proj
+ hidden_states = attn.to_out[0](hidden_states)
+ # dropout
+ hidden_states = attn.to_out[1](hidden_states)
+
+ if input_ndim == 4:
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
+
+ if attn.residual_connection:
+ hidden_states = hidden_states + residual
+
+ hidden_states = hidden_states / attn.rescale_output_factor
+
+ return hidden_states
+
+
+class StableAudioAttnProcessor2_0:
+ r"""
+ Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). This is
+    used in the Stable Audio model. It applies a rotary embedding on the query and key vectors, and allows MHA, GQA or
+    MQA.
+ """
+
+ def __init__(self):
+ if not hasattr(F, "scaled_dot_product_attention"):
+ raise ImportError(
+ "StableAudioAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
+ )
+
+ def apply_partial_rotary_emb(
+ self,
+ x: torch.Tensor,
+ freqs_cis: Tuple[torch.Tensor],
+ ) -> torch.Tensor:
+ from .embeddings import apply_rotary_emb
+
+ rot_dim = freqs_cis[0].shape[-1]
+ x_to_rotate, x_unrotated = x[..., :rot_dim], x[..., rot_dim:]
+
+ x_rotated = apply_rotary_emb(x_to_rotate, freqs_cis, use_real=True, use_real_unbind_dim=-2)
+
+ out = torch.cat((x_rotated, x_unrotated), dim=-1)
+ return out
+
+ def __call__(
+ self,
+ attn: Attention,
+ hidden_states: torch.Tensor,
+ encoder_hidden_states: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ rotary_emb: Optional[torch.Tensor] = None,
+ ) -> torch.Tensor:
+ from .embeddings import apply_rotary_emb
+
+ residual = hidden_states
+
+ input_ndim = hidden_states.ndim
+
+ if input_ndim == 4:
+ batch_size, channel, height, width = hidden_states.shape
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
+
+ batch_size, sequence_length, _ = (
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
+ )
+
+ if attention_mask is not None:
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
+ attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
+
+ query = attn.to_q(hidden_states)
+
+ if encoder_hidden_states is None:
+ encoder_hidden_states = hidden_states
+ elif attn.norm_cross:
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
+
+ key = attn.to_k(encoder_hidden_states)
+ value = attn.to_v(encoder_hidden_states)
+
+ head_dim = query.shape[-1] // attn.heads
+ kv_heads = key.shape[-1] // head_dim
+
+ query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+ key = key.view(batch_size, -1, kv_heads, head_dim).transpose(1, 2)
+ value = value.view(batch_size, -1, kv_heads, head_dim).transpose(1, 2)
+
+ if kv_heads != attn.heads:
+ # if GQA or MQA, repeat the key/value heads to reach the number of query heads.
+ heads_per_kv_head = attn.heads // kv_heads
+ key = torch.repeat_interleave(key, heads_per_kv_head, dim=1)
+ value = torch.repeat_interleave(value, heads_per_kv_head, dim=1)
+
+ if attn.norm_q is not None:
+ query = attn.norm_q(query)
+ if attn.norm_k is not None:
+ key = attn.norm_k(key)
+
+ # Apply RoPE if needed
+ if rotary_emb is not None:
+ query_dtype = query.dtype
+ key_dtype = key.dtype
+ query = query.to(torch.float32)
+ key = key.to(torch.float32)
+
+ rot_dim = rotary_emb[0].shape[-1]
+ query_to_rotate, query_unrotated = query[..., :rot_dim], query[..., rot_dim:]
+ query_rotated = apply_rotary_emb(query_to_rotate, rotary_emb, use_real=True, use_real_unbind_dim=-2)
+
+ query = torch.cat((query_rotated, query_unrotated), dim=-1)
+
+ if not attn.is_cross_attention:
+ key_to_rotate, key_unrotated = key[..., :rot_dim], key[..., rot_dim:]
+ key_rotated = apply_rotary_emb(key_to_rotate, rotary_emb, use_real=True, use_real_unbind_dim=-2)
+
+ key = torch.cat((key_rotated, key_unrotated), dim=-1)
+
+ query = query.to(query_dtype)
+ key = key.to(key_dtype)
+
+ hidden_states = F.scaled_dot_product_attention(
+ query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
+ )
+
+ hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
+ hidden_states = hidden_states.to(query.dtype)
+
+ # linear proj
+ hidden_states = attn.to_out[0](hidden_states)
+ # dropout
+ hidden_states = attn.to_out[1](hidden_states)
+
+ if input_ndim == 4:
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
+
+ if attn.residual_connection:
+ hidden_states = hidden_states + residual
+
+ hidden_states = hidden_states / attn.rescale_output_factor
+
+ return hidden_states
+
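+# --- Illustrative sketch (not part of the original file) ---------------------
+# For GQA/MQA, StableAudioAttnProcessor2_0 repeats each key/value head so that
+# the K/V head count matches the query head count before SDPA is called. With
+# hypothetical sizes (8 query heads, 2 K/V heads):
+def _example_gqa_repeat_kv():
+    import torch
+
+    batch, heads, kv_heads, seq, head_dim = 1, 8, 2, 16, 32
+    key = torch.randn(batch, kv_heads, seq, head_dim)
+
+    heads_per_kv_head = heads // kv_heads                        # 4
+    key = torch.repeat_interleave(key, heads_per_kv_head, dim=1)
+    return key.shape                                             # [1, 8, 16, 32]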
+
+class HunyuanAttnProcessor2_0:
+ r"""
+ Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). This is
+    used in the HunyuanDiT model. It applies a normalization layer and rotary embedding on the query and key vectors.
+ """
+
+ def __init__(self):
+ if not hasattr(F, "scaled_dot_product_attention"):
+ raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
+
+ def __call__(
+ self,
+ attn: Attention,
+ hidden_states: torch.Tensor,
+ encoder_hidden_states: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ temb: Optional[torch.Tensor] = None,
+ image_rotary_emb: Optional[torch.Tensor] = None,
+ ) -> torch.Tensor:
+ from .embeddings import apply_rotary_emb
+
+ residual = hidden_states
+ if attn.spatial_norm is not None:
+ hidden_states = attn.spatial_norm(hidden_states, temb)
+
+ input_ndim = hidden_states.ndim
+
+ if input_ndim == 4:
+ batch_size, channel, height, width = hidden_states.shape
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
+
+ batch_size, sequence_length, _ = (
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
+ )
+
+ if attention_mask is not None:
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
+ attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
+
+ if attn.group_norm is not None:
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
+
+ query = attn.to_q(hidden_states)
+
+ if encoder_hidden_states is None:
+ encoder_hidden_states = hidden_states
+ elif attn.norm_cross:
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
+
+ key = attn.to_k(encoder_hidden_states)
+ value = attn.to_v(encoder_hidden_states)
+
+ inner_dim = key.shape[-1]
+ head_dim = inner_dim // attn.heads
+
+ query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+ key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+ value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+ if attn.norm_q is not None:
+ query = attn.norm_q(query)
+ if attn.norm_k is not None:
+ key = attn.norm_k(key)
+
+ # Apply RoPE if needed
+ if image_rotary_emb is not None:
+ query = apply_rotary_emb(query, image_rotary_emb)
+ if not attn.is_cross_attention:
+ key = apply_rotary_emb(key, image_rotary_emb)
+
+ hidden_states = F.scaled_dot_product_attention(
+ query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
+ )
+
+ hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
+ hidden_states = hidden_states.to(query.dtype)
+
+ # linear proj
+ hidden_states = attn.to_out[0](hidden_states)
+ # dropout
+ hidden_states = attn.to_out[1](hidden_states)
+
+ if input_ndim == 4:
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
+
+ if attn.residual_connection:
+ hidden_states = hidden_states + residual
+
+ hidden_states = hidden_states / attn.rescale_output_factor
+
+ return hidden_states
+
+
+class FusedHunyuanAttnProcessor2_0:
+ r"""
+ Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0) with fused
+    projection layers. This is used in the HunyuanDiT model. It applies a normalization layer and rotary embedding on
+    the query and key vectors.
+ """
+
+ def __init__(self):
+ if not hasattr(F, "scaled_dot_product_attention"):
+ raise ImportError(
+ "FusedHunyuanAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
+ )
+
+ def __call__(
+ self,
+ attn: Attention,
+ hidden_states: torch.Tensor,
+ encoder_hidden_states: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ temb: Optional[torch.Tensor] = None,
+ image_rotary_emb: Optional[torch.Tensor] = None,
+ ) -> torch.Tensor:
+ from .embeddings import apply_rotary_emb
+
+ residual = hidden_states
+ if attn.spatial_norm is not None:
+ hidden_states = attn.spatial_norm(hidden_states, temb)
+
+ input_ndim = hidden_states.ndim
+
+ if input_ndim == 4:
+ batch_size, channel, height, width = hidden_states.shape
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
+
+ batch_size, sequence_length, _ = (
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
+ )
+
+ if attention_mask is not None:
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
+ attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
+
+ if attn.group_norm is not None:
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
+
+ if encoder_hidden_states is None:
+ qkv = attn.to_qkv(hidden_states)
+ split_size = qkv.shape[-1] // 3
+ query, key, value = torch.split(qkv, split_size, dim=-1)
+ else:
+ if attn.norm_cross:
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
+ query = attn.to_q(hidden_states)
+
+ kv = attn.to_kv(encoder_hidden_states)
+ split_size = kv.shape[-1] // 2
+ key, value = torch.split(kv, split_size, dim=-1)
+
+ inner_dim = key.shape[-1]
+ head_dim = inner_dim // attn.heads
+
+ query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+ key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+ value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+ if attn.norm_q is not None:
+ query = attn.norm_q(query)
+ if attn.norm_k is not None:
+ key = attn.norm_k(key)
+
+ # Apply RoPE if needed
+ if image_rotary_emb is not None:
+ query = apply_rotary_emb(query, image_rotary_emb)
+ if not attn.is_cross_attention:
+ key = apply_rotary_emb(key, image_rotary_emb)
+
+ hidden_states = F.scaled_dot_product_attention(
+ query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
+ )
+
+ hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
+ hidden_states = hidden_states.to(query.dtype)
+
+ # linear proj
+ hidden_states = attn.to_out[0](hidden_states)
+ # dropout
+ hidden_states = attn.to_out[1](hidden_states)
+
+ if input_ndim == 4:
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
+
+ if attn.residual_connection:
+ hidden_states = hidden_states + residual
+
+ hidden_states = hidden_states / attn.rescale_output_factor
+
+ return hidden_states
+
+
+class PAGHunyuanAttnProcessor2_0:
+ r"""
+ Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). This is
+    used in the HunyuanDiT model. It applies a normalization layer and rotary embedding on the query and key vectors.
+    This variant of the processor employs [Perturbed Attention Guidance](https://arxiv.org/abs/2403.17377).
+ """
+
+ def __init__(self):
+ if not hasattr(F, "scaled_dot_product_attention"):
+ raise ImportError(
+ "PAGHunyuanAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
+ )
+
+ def __call__(
+ self,
+ attn: Attention,
+ hidden_states: torch.Tensor,
+ encoder_hidden_states: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ temb: Optional[torch.Tensor] = None,
+ image_rotary_emb: Optional[torch.Tensor] = None,
+ ) -> torch.Tensor:
+ from .embeddings import apply_rotary_emb
+
+ residual = hidden_states
+ if attn.spatial_norm is not None:
+ hidden_states = attn.spatial_norm(hidden_states, temb)
+
+ input_ndim = hidden_states.ndim
+
+ if input_ndim == 4:
+ batch_size, channel, height, width = hidden_states.shape
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
+
+ # chunk
+ hidden_states_org, hidden_states_ptb = hidden_states.chunk(2)
+
+ # 1. Original Path
+ batch_size, sequence_length, _ = (
+ hidden_states_org.shape if encoder_hidden_states is None else encoder_hidden_states.shape
+ )
+
+ if attention_mask is not None:
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
+ attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
+
+ if attn.group_norm is not None:
+ hidden_states_org = attn.group_norm(hidden_states_org.transpose(1, 2)).transpose(1, 2)
+
+ query = attn.to_q(hidden_states_org)
+
+ if encoder_hidden_states is None:
+ encoder_hidden_states = hidden_states_org
+ elif attn.norm_cross:
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
+
+ key = attn.to_k(encoder_hidden_states)
+ value = attn.to_v(encoder_hidden_states)
+
+ inner_dim = key.shape[-1]
+ head_dim = inner_dim // attn.heads
+
+ query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+ key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+ value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+ if attn.norm_q is not None:
+ query = attn.norm_q(query)
+ if attn.norm_k is not None:
+ key = attn.norm_k(key)
+
+ # Apply RoPE if needed
+ if image_rotary_emb is not None:
+ query = apply_rotary_emb(query, image_rotary_emb)
+ if not attn.is_cross_attention:
+ key = apply_rotary_emb(key, image_rotary_emb)
+
+ hidden_states_org = F.scaled_dot_product_attention(
+ query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
+ )
+
+ hidden_states_org = hidden_states_org.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
+ hidden_states_org = hidden_states_org.to(query.dtype)
+
+ # linear proj
+ hidden_states_org = attn.to_out[0](hidden_states_org)
+ # dropout
+ hidden_states_org = attn.to_out[1](hidden_states_org)
+
+ if input_ndim == 4:
+ hidden_states_org = hidden_states_org.transpose(-1, -2).reshape(batch_size, channel, height, width)
+
+ # 2. Perturbed Path
+ if attn.group_norm is not None:
+ hidden_states_ptb = attn.group_norm(hidden_states_ptb.transpose(1, 2)).transpose(1, 2)
+
+ hidden_states_ptb = attn.to_v(hidden_states_ptb)
+ hidden_states_ptb = hidden_states_ptb.to(query.dtype)
+
+ # linear proj
+ hidden_states_ptb = attn.to_out[0](hidden_states_ptb)
+ # dropout
+ hidden_states_ptb = attn.to_out[1](hidden_states_ptb)
+
+ if input_ndim == 4:
+ hidden_states_ptb = hidden_states_ptb.transpose(-1, -2).reshape(batch_size, channel, height, width)
+
+ # cat
+ hidden_states = torch.cat([hidden_states_org, hidden_states_ptb])
+
+ if attn.residual_connection:
+ hidden_states = hidden_states + residual
+
+ hidden_states = hidden_states / attn.rescale_output_factor
+
+ return hidden_states
+
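+# --- Illustrative sketch (not part of the original file) ---------------------
+# In the perturbed branch of Perturbed Attention Guidance the attention map is
+# effectively replaced by an identity, so the branch reduces to the value and
+# output projections with no query/key interaction. A simplified stand-alone
+# version of the "2. Perturbed Path" block above (group norm and dtype handling
+# omitted; takes an existing `Attention` module):
+def _example_pag_perturbed_branch(attn, hidden_states_ptb):
+    out = attn.to_v(hidden_states_ptb)     # identity "attention": values pass straight through
+    out = attn.to_out[0](out)              # linear proj
+    out = attn.to_out[1](out)              # dropout
+    return out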
+
+class PAGCFGHunyuanAttnProcessor2_0:
+ r"""
+ Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). This is
+    used in the HunyuanDiT model. It applies a normalization layer and rotary embedding on the query and key vectors.
+    This variant of the processor employs [Perturbed Attention Guidance](https://arxiv.org/abs/2403.17377).
+ """
+
+ def __init__(self):
+ if not hasattr(F, "scaled_dot_product_attention"):
+ raise ImportError(
+ "PAGCFGHunyuanAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
+ )
+
+ def __call__(
+ self,
+ attn: Attention,
+ hidden_states: torch.Tensor,
+ encoder_hidden_states: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ temb: Optional[torch.Tensor] = None,
+ image_rotary_emb: Optional[torch.Tensor] = None,
+ ) -> torch.Tensor:
+ from .embeddings import apply_rotary_emb
+
+ residual = hidden_states
+ if attn.spatial_norm is not None:
+ hidden_states = attn.spatial_norm(hidden_states, temb)
+
+ input_ndim = hidden_states.ndim
+
+ if input_ndim == 4:
+ batch_size, channel, height, width = hidden_states.shape
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
+
+ # chunk
+ hidden_states_uncond, hidden_states_org, hidden_states_ptb = hidden_states.chunk(3)
+ hidden_states_org = torch.cat([hidden_states_uncond, hidden_states_org])
+
+ # 1. Original Path
+ batch_size, sequence_length, _ = (
+ hidden_states_org.shape if encoder_hidden_states is None else encoder_hidden_states.shape
+ )
+
+ if attention_mask is not None:
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
+ attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
+
+ if attn.group_norm is not None:
+ hidden_states_org = attn.group_norm(hidden_states_org.transpose(1, 2)).transpose(1, 2)
+
+ query = attn.to_q(hidden_states_org)
+
+ if encoder_hidden_states is None:
+ encoder_hidden_states = hidden_states_org
+ elif attn.norm_cross:
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
+
+ key = attn.to_k(encoder_hidden_states)
+ value = attn.to_v(encoder_hidden_states)
+
+ inner_dim = key.shape[-1]
+ head_dim = inner_dim // attn.heads
+
+ query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+ key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+ value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+ if attn.norm_q is not None:
+ query = attn.norm_q(query)
+ if attn.norm_k is not None:
+ key = attn.norm_k(key)
+
+ # Apply RoPE if needed
+ if image_rotary_emb is not None:
+ query = apply_rotary_emb(query, image_rotary_emb)
+ if not attn.is_cross_attention:
+ key = apply_rotary_emb(key, image_rotary_emb)
+
+ hidden_states_org = F.scaled_dot_product_attention(
+ query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
+ )
+
+ hidden_states_org = hidden_states_org.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
+ hidden_states_org = hidden_states_org.to(query.dtype)
+
+ # linear proj
+ hidden_states_org = attn.to_out[0](hidden_states_org)
+ # dropout
+ hidden_states_org = attn.to_out[1](hidden_states_org)
+
+ if input_ndim == 4:
+ hidden_states_org = hidden_states_org.transpose(-1, -2).reshape(batch_size, channel, height, width)
+
+ # 2. Perturbed Path
+ if attn.group_norm is not None:
+ hidden_states_ptb = attn.group_norm(hidden_states_ptb.transpose(1, 2)).transpose(1, 2)
+
+ hidden_states_ptb = attn.to_v(hidden_states_ptb)
+ hidden_states_ptb = hidden_states_ptb.to(query.dtype)
+
+ # linear proj
+ hidden_states_ptb = attn.to_out[0](hidden_states_ptb)
+ # dropout
+ hidden_states_ptb = attn.to_out[1](hidden_states_ptb)
+
+ if input_ndim == 4:
+ hidden_states_ptb = hidden_states_ptb.transpose(-1, -2).reshape(batch_size, channel, height, width)
+
+ # cat
+ hidden_states = torch.cat([hidden_states_org, hidden_states_ptb])
+
+ if attn.residual_connection:
+ hidden_states = hidden_states + residual
+
+ hidden_states = hidden_states / attn.rescale_output_factor
+
+ return hidden_states
+
+
+class LuminaAttnProcessor2_0:
+ r"""
+ Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). This is
+    used in the LuminaNextDiT model. It applies a normalization layer and rotary embedding on the query and key
+    vectors.
+ """
+
+ def __init__(self):
+ if not hasattr(F, "scaled_dot_product_attention"):
+ raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
+
+ def __call__(
+ self,
+ attn: Attention,
+ hidden_states: torch.Tensor,
+ encoder_hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ query_rotary_emb: Optional[torch.Tensor] = None,
+ key_rotary_emb: Optional[torch.Tensor] = None,
+ base_sequence_length: Optional[int] = None,
+ ) -> torch.Tensor:
+ from .embeddings import apply_rotary_emb
+
+ input_ndim = hidden_states.ndim
+
+ if input_ndim == 4:
+ batch_size, channel, height, width = hidden_states.shape
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
+
+ batch_size, sequence_length, _ = hidden_states.shape
+
+ # Get Query-Key-Value Pair
+ query = attn.to_q(hidden_states)
+ key = attn.to_k(encoder_hidden_states)
+ value = attn.to_v(encoder_hidden_states)
+
+ query_dim = query.shape[-1]
+ inner_dim = key.shape[-1]
+ head_dim = query_dim // attn.heads
+ dtype = query.dtype
+
+ # Get key-value heads
+ kv_heads = inner_dim // head_dim
+
+ # Apply Query-Key Norm if needed
+ if attn.norm_q is not None:
+ query = attn.norm_q(query)
+ if attn.norm_k is not None:
+ key = attn.norm_k(key)
+
+ query = query.view(batch_size, -1, attn.heads, head_dim)
+
+ key = key.view(batch_size, -1, kv_heads, head_dim)
+ value = value.view(batch_size, -1, kv_heads, head_dim)
+
+ # Apply RoPE if needed
+ if query_rotary_emb is not None:
+ query = apply_rotary_emb(query, query_rotary_emb, use_real=False)
+ if key_rotary_emb is not None:
+ key = apply_rotary_emb(key, key_rotary_emb, use_real=False)
+
+ query, key = query.to(dtype), key.to(dtype)
+
+ # Apply proportional attention if true
+ if key_rotary_emb is None:
+ softmax_scale = None
+ else:
+ if base_sequence_length is not None:
+ softmax_scale = math.sqrt(math.log(sequence_length, base_sequence_length)) * attn.scale
+ else:
+ softmax_scale = attn.scale
+
+        # perform Grouped-Query Attention (GQA)
+ n_rep = attn.heads // kv_heads
+ if n_rep >= 1:
+ key = key.unsqueeze(3).repeat(1, 1, 1, n_rep, 1).flatten(2, 3)
+ value = value.unsqueeze(3).repeat(1, 1, 1, n_rep, 1).flatten(2, 3)
+
+ attention_mask = attention_mask.bool().view(batch_size, 1, 1, -1)
+ attention_mask = attention_mask.expand(-1, attn.heads, sequence_length, -1)
+
+ query = query.transpose(1, 2)
+ key = key.transpose(1, 2)
+ value = value.transpose(1, 2)
+
+ hidden_states = F.scaled_dot_product_attention(
+ query, key, value, attn_mask=attention_mask, scale=softmax_scale
+ )
+ hidden_states = hidden_states.transpose(1, 2).to(dtype)
+
+ return hidden_states
+
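+# --- Illustrative sketch (not part of the original file) ---------------------
+# The "proportional attention" scale above grows with the logarithm of the
+# sequence length relative to a base length:
+#     softmax_scale = sqrt(log_base(seq_len)) * attn.scale
+# A worked example with hypothetical numbers:
+def _example_proportional_scale():
+    import math
+
+    base_sequence_length = 256
+    sequence_length = 1024
+    attn_scale = 1.0 / math.sqrt(64)       # e.g. head_dim = 64
+
+    # log_256(1024) = 1.25, sqrt(1.25) ~= 1.118
+    softmax_scale = math.sqrt(math.log(sequence_length, base_sequence_length)) * attn_scale
+    return softmax_scale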
+
+class FusedAttnProcessor2_0:
+ r"""
+ Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). It uses
+ fused projection layers. For self-attention modules, all projection matrices (i.e., query, key, value) are fused.
+ For cross-attention modules, key and value projection matrices are fused.
+
+    This API is currently 🧪 experimental in nature and can change in the future.
+    """
+
+ def __init__(self):
+ if not hasattr(F, "scaled_dot_product_attention"):
+ raise ImportError(
+ "FusedAttnProcessor2_0 requires at least PyTorch 2.0, to use it. Please upgrade PyTorch to > 2.0."
+ )
+
+ def __call__(
+ self,
+ attn: Attention,
+ hidden_states: torch.Tensor,
+ encoder_hidden_states: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ temb: Optional[torch.Tensor] = None,
+ *args,
+ **kwargs,
+ ) -> torch.Tensor:
+ if len(args) > 0 or kwargs.get("scale", None) is not None:
+ deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
+ deprecate("scale", "1.0.0", deprecation_message)
+
+ residual = hidden_states
+ if attn.spatial_norm is not None:
+ hidden_states = attn.spatial_norm(hidden_states, temb)
+
+ input_ndim = hidden_states.ndim
+
+ if input_ndim == 4:
+ batch_size, channel, height, width = hidden_states.shape
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
+
+ batch_size, sequence_length, _ = (
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
+ )
+
+ if attention_mask is not None:
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
+ attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
+
+ if attn.group_norm is not None:
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
+
+ if encoder_hidden_states is None:
+ qkv = attn.to_qkv(hidden_states)
+ split_size = qkv.shape[-1] // 3
+ query, key, value = torch.split(qkv, split_size, dim=-1)
+ else:
+ if attn.norm_cross:
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
+ query = attn.to_q(hidden_states)
+
+ kv = attn.to_kv(encoder_hidden_states)
+ split_size = kv.shape[-1] // 2
+ key, value = torch.split(kv, split_size, dim=-1)
+
+ inner_dim = key.shape[-1]
+ head_dim = inner_dim // attn.heads
+
+ query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+ key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+ value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+ if attn.norm_q is not None:
+ query = attn.norm_q(query)
+ if attn.norm_k is not None:
+ key = attn.norm_k(key)
+
+ hidden_states = F.scaled_dot_product_attention(
+ query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
+ )
+
+ hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
+ hidden_states = hidden_states.to(query.dtype)
+
+ # linear proj
+ hidden_states = attn.to_out[0](hidden_states)
+ # dropout
+ hidden_states = attn.to_out[1](hidden_states)
+
+ if input_ndim == 4:
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
+
+ if attn.residual_connection:
+ hidden_states = hidden_states + residual
+
+ hidden_states = hidden_states / attn.rescale_output_factor
+
+ return hidden_states
+
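+
+# Illustrative sketch: `_example_fused_qkv_split` is a hypothetical helper (not part of the
+# upstream diffusers API) showing how the fused QKV projection used by FusedAttnProcessor2_0
+# for self-attention is split back into separate query/key/value tensors. Shapes are examples.
+def _example_fused_qkv_split():
+    qkv = torch.randn(2, 77, 3 * 320)  # [batch, seq_len, 3 * inner_dim]
+    split_size = qkv.shape[-1] // 3
+    query, key, value = torch.split(qkv, split_size, dim=-1)
+    return query.shape, key.shape, value.shape  # each torch.Size([2, 77, 320])
+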
+
+class CustomDiffusionXFormersAttnProcessor(nn.Module):
+ r"""
+ Processor for implementing memory efficient attention using xFormers for the Custom Diffusion method.
+
+ Args:
+ train_kv (`bool`, defaults to `True`):
+ Whether to train new key and value matrices corresponding to the text features.
+ train_q_out (`bool`, defaults to `False`):
+ Whether to train new query matrices corresponding to the latent image features.
+ hidden_size (`int`, *optional*, defaults to `None`):
+ The hidden size of the attention layer.
+ cross_attention_dim (`int`, *optional*, defaults to `None`):
+ The number of channels in the `encoder_hidden_states`.
+ out_bias (`bool`, defaults to `True`):
+ Whether to include a bias parameter in the output projection trained when `train_q_out` is enabled.
+ dropout (`float`, *optional*, defaults to 0.0):
+ The dropout probability to use.
+ attention_op (`Callable`, *optional*, defaults to `None`):
+ The base
+ [operator](https://facebookresearch.github.io/xformers/components/ops.html#xformers.ops.AttentionOpBase) to use
+ as the attention operator. It is recommended to set to `None`, and allow xFormers to choose the best operator.
+ """
+
+ def __init__(
+ self,
+ train_kv: bool = True,
+ train_q_out: bool = False,
+ hidden_size: Optional[int] = None,
+ cross_attention_dim: Optional[int] = None,
+ out_bias: bool = True,
+ dropout: float = 0.0,
+ attention_op: Optional[Callable] = None,
+ ):
+ super().__init__()
+ self.train_kv = train_kv
+ self.train_q_out = train_q_out
+
+ self.hidden_size = hidden_size
+ self.cross_attention_dim = cross_attention_dim
+ self.attention_op = attention_op
+
+ # `_custom_diffusion` id for easy serialization and loading.
+ if self.train_kv:
+ self.to_k_custom_diffusion = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
+ self.to_v_custom_diffusion = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
+ if self.train_q_out:
+ self.to_q_custom_diffusion = nn.Linear(hidden_size, hidden_size, bias=False)
+ self.to_out_custom_diffusion = nn.ModuleList([])
+ self.to_out_custom_diffusion.append(nn.Linear(hidden_size, hidden_size, bias=out_bias))
+ self.to_out_custom_diffusion.append(nn.Dropout(dropout))
+
+ def __call__(
+ self,
+ attn: Attention,
+ hidden_states: torch.Tensor,
+ encoder_hidden_states: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ ) -> torch.Tensor:
+ batch_size, sequence_length, _ = (
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
+ )
+
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
+
+ if self.train_q_out:
+ query = self.to_q_custom_diffusion(hidden_states).to(attn.to_q.weight.dtype)
+ else:
+ query = attn.to_q(hidden_states.to(attn.to_q.weight.dtype))
+
+ if encoder_hidden_states is None:
+ crossattn = False
+ encoder_hidden_states = hidden_states
+ else:
+ crossattn = True
+ if attn.norm_cross:
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
+
+ if self.train_kv:
+ key = self.to_k_custom_diffusion(encoder_hidden_states.to(self.to_k_custom_diffusion.weight.dtype))
+ value = self.to_v_custom_diffusion(encoder_hidden_states.to(self.to_v_custom_diffusion.weight.dtype))
+ key = key.to(attn.to_q.weight.dtype)
+ value = value.to(attn.to_q.weight.dtype)
+ else:
+ key = attn.to_k(encoder_hidden_states)
+ value = attn.to_v(encoder_hidden_states)
+
+ if crossattn:
+ detach = torch.ones_like(key)
+ detach[:, :1, :] = detach[:, :1, :] * 0.0
+ key = detach * key + (1 - detach) * key.detach()
+ value = detach * value + (1 - detach) * value.detach()
+
+ query = attn.head_to_batch_dim(query).contiguous()
+ key = attn.head_to_batch_dim(key).contiguous()
+ value = attn.head_to_batch_dim(value).contiguous()
+
+ hidden_states = xformers.ops.memory_efficient_attention(
+ query, key, value, attn_bias=attention_mask, op=self.attention_op, scale=attn.scale
+ )
+ hidden_states = hidden_states.to(query.dtype)
+ hidden_states = attn.batch_to_head_dim(hidden_states)
+
+ if self.train_q_out:
+ # linear proj
+ hidden_states = self.to_out_custom_diffusion[0](hidden_states)
+ # dropout
+ hidden_states = self.to_out_custom_diffusion[1](hidden_states)
+ else:
+ # linear proj
+ hidden_states = attn.to_out[0](hidden_states)
+ # dropout
+ hidden_states = attn.to_out[1](hidden_states)
+
+ return hidden_states
+
+
+class CustomDiffusionAttnProcessor2_0(nn.Module):
+ r"""
+ Processor for implementing attention for the Custom Diffusion method using PyTorch 2.0’s memory-efficient scaled
+ dot-product attention.
+
+ Args:
+ train_kv (`bool`, defaults to `True`):
+ Whether to train new key and value matrices corresponding to the text features.
+ train_q_out (`bool`, defaults to `True`):
+ Whether to train new query matrices corresponding to the latent image features.
+ hidden_size (`int`, *optional*, defaults to `None`):
+ The hidden size of the attention layer.
+ cross_attention_dim (`int`, *optional*, defaults to `None`):
+ The number of channels in the `encoder_hidden_states`.
+ out_bias (`bool`, defaults to `True`):
+ Whether to include a bias parameter in the output projection trained when `train_q_out` is enabled.
+ dropout (`float`, *optional*, defaults to 0.0):
+ The dropout probability to use.
+ """
+
+ def __init__(
+ self,
+ train_kv: bool = True,
+ train_q_out: bool = True,
+ hidden_size: Optional[int] = None,
+ cross_attention_dim: Optional[int] = None,
+ out_bias: bool = True,
+ dropout: float = 0.0,
+ ):
+ super().__init__()
+ self.train_kv = train_kv
+ self.train_q_out = train_q_out
+
+ self.hidden_size = hidden_size
+ self.cross_attention_dim = cross_attention_dim
+
+ # `_custom_diffusion` id for easy serialization and loading.
+ if self.train_kv:
+ self.to_k_custom_diffusion = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
+ self.to_v_custom_diffusion = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
+ if self.train_q_out:
+ self.to_q_custom_diffusion = nn.Linear(hidden_size, hidden_size, bias=False)
+ self.to_out_custom_diffusion = nn.ModuleList([])
+ self.to_out_custom_diffusion.append(nn.Linear(hidden_size, hidden_size, bias=out_bias))
+ self.to_out_custom_diffusion.append(nn.Dropout(dropout))
+
+ def __call__(
+ self,
+ attn: Attention,
+ hidden_states: torch.Tensor,
+ encoder_hidden_states: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ ) -> torch.Tensor:
+ batch_size, sequence_length, _ = hidden_states.shape
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
+ if self.train_q_out:
+ query = self.to_q_custom_diffusion(hidden_states)
+ else:
+ query = attn.to_q(hidden_states)
+
+ if encoder_hidden_states is None:
+ crossattn = False
+ encoder_hidden_states = hidden_states
+ else:
+ crossattn = True
+ if attn.norm_cross:
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
+
+ if self.train_kv:
+ key = self.to_k_custom_diffusion(encoder_hidden_states.to(self.to_k_custom_diffusion.weight.dtype))
+ value = self.to_v_custom_diffusion(encoder_hidden_states.to(self.to_v_custom_diffusion.weight.dtype))
+ key = key.to(attn.to_q.weight.dtype)
+ value = value.to(attn.to_q.weight.dtype)
+
+ else:
+ key = attn.to_k(encoder_hidden_states)
+ value = attn.to_v(encoder_hidden_states)
+
+ if crossattn:
+ detach = torch.ones_like(key)
+ detach[:, :1, :] = detach[:, :1, :] * 0.0
+ key = detach * key + (1 - detach) * key.detach()
+ value = detach * value + (1 - detach) * value.detach()
+
+ inner_dim = hidden_states.shape[-1]
+
+ head_dim = inner_dim // attn.heads
+ query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+ key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+ value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+ hidden_states = F.scaled_dot_product_attention(
+ query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
+ )
+
+ hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
+ hidden_states = hidden_states.to(query.dtype)
+
+ if self.train_q_out:
+ # linear proj
+ hidden_states = self.to_out_custom_diffusion[0](hidden_states)
+ # dropout
+ hidden_states = self.to_out_custom_diffusion[1](hidden_states)
+ else:
+ # linear proj
+ hidden_states = attn.to_out[0](hidden_states)
+ # dropout
+ hidden_states = attn.to_out[1](hidden_states)
+
+ return hidden_states
+
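+
+# Illustrative sketch: `_example_custom_diffusion_config` is a hypothetical helper with made-up
+# sizes, showing that with `train_kv=True` the processor owns fresh key/value projections that
+# map 768-dim text features into a 320-dim attention layer.
+def _example_custom_diffusion_config():
+    processor = CustomDiffusionAttnProcessor2_0(
+        train_kv=True, train_q_out=False, hidden_size=320, cross_attention_dim=768
+    )
+    return processor.to_k_custom_diffusion.weight.shape  # torch.Size([320, 768])
+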
+
+class SlicedAttnProcessor:
+ r"""
+ Processor for implementing sliced attention.
+
+ Args:
+ slice_size (`int`, *optional*):
+ The size of each slice of the attention computation. Attention is computed over as many slices as
+ `attention_head_dim // slice_size`, and `attention_head_dim` must be a multiple of `slice_size`.
+ """
+
+ def __init__(self, slice_size: int):
+ self.slice_size = slice_size
+
+ def __call__(
+ self,
+ attn: Attention,
+ hidden_states: torch.Tensor,
+ encoder_hidden_states: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ ) -> torch.Tensor:
+ residual = hidden_states
+
+ input_ndim = hidden_states.ndim
+
+ if input_ndim == 4:
+ batch_size, channel, height, width = hidden_states.shape
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
+
+ batch_size, sequence_length, _ = (
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
+ )
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
+
+ if attn.group_norm is not None:
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
+
+ query = attn.to_q(hidden_states)
+ dim = query.shape[-1]
+ query = attn.head_to_batch_dim(query)
+
+ if encoder_hidden_states is None:
+ encoder_hidden_states = hidden_states
+ elif attn.norm_cross:
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
+
+ key = attn.to_k(encoder_hidden_states)
+ value = attn.to_v(encoder_hidden_states)
+ key = attn.head_to_batch_dim(key)
+ value = attn.head_to_batch_dim(value)
+
+ batch_size_attention, query_tokens, _ = query.shape
+ hidden_states = torch.zeros(
+ (batch_size_attention, query_tokens, dim // attn.heads), device=query.device, dtype=query.dtype
+ )
+
+ for i in range((batch_size_attention - 1) // self.slice_size + 1):
+ start_idx = i * self.slice_size
+ end_idx = (i + 1) * self.slice_size
+
+ query_slice = query[start_idx:end_idx]
+ key_slice = key[start_idx:end_idx]
+ attn_mask_slice = attention_mask[start_idx:end_idx] if attention_mask is not None else None
+
+ attn_slice = attn.get_attention_scores(query_slice, key_slice, attn_mask_slice)
+
+ attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx])
+
+ hidden_states[start_idx:end_idx] = attn_slice
+
+ hidden_states = attn.batch_to_head_dim(hidden_states)
+
+ # linear proj
+ hidden_states = attn.to_out[0](hidden_states)
+ # dropout
+ hidden_states = attn.to_out[1](hidden_states)
+
+ if input_ndim == 4:
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
+
+ if attn.residual_connection:
+ hidden_states = hidden_states + residual
+
+ hidden_states = hidden_states / attn.rescale_output_factor
+
+ return hidden_states
+
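+
+# Illustrative sketch: `_example_attention_slices` is a hypothetical helper showing how
+# SlicedAttnProcessor walks the batched head dimension in chunks of `slice_size`, trading a
+# little speed for a much smaller peak attention matrix.
+def _example_attention_slices():
+    batch_size_attention, slice_size = 16, 4
+    return [
+        (i * slice_size, (i + 1) * slice_size)
+        for i in range((batch_size_attention - 1) // slice_size + 1)
+    ]  # [(0, 4), (4, 8), (8, 12), (12, 16)]
+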
+
+class SlicedAttnAddedKVProcessor:
+ r"""
+ Processor for implementing sliced attention with extra learnable key and value matrices for the text encoder.
+
+ Args:
+ slice_size (`int`, *optional*):
+ The size of each slice of the attention computation. Attention is computed over as many slices as
+ `attention_head_dim // slice_size`, and `attention_head_dim` must be a multiple of `slice_size`.
+ """
+
+ def __init__(self, slice_size):
+ self.slice_size = slice_size
+
+ def __call__(
+ self,
+ attn: "Attention",
+ hidden_states: torch.Tensor,
+ encoder_hidden_states: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ temb: Optional[torch.Tensor] = None,
+ ) -> torch.Tensor:
+ residual = hidden_states
+
+ if attn.spatial_norm is not None:
+ hidden_states = attn.spatial_norm(hidden_states, temb)
+
+ hidden_states = hidden_states.view(hidden_states.shape[0], hidden_states.shape[1], -1).transpose(1, 2)
+
+ batch_size, sequence_length, _ = hidden_states.shape
+
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
+
+ if encoder_hidden_states is None:
+ encoder_hidden_states = hidden_states
+ elif attn.norm_cross:
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
+
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
+
+ query = attn.to_q(hidden_states)
+ dim = query.shape[-1]
+ query = attn.head_to_batch_dim(query)
+
+ encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
+ encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
+
+ encoder_hidden_states_key_proj = attn.head_to_batch_dim(encoder_hidden_states_key_proj)
+ encoder_hidden_states_value_proj = attn.head_to_batch_dim(encoder_hidden_states_value_proj)
+
+ if not attn.only_cross_attention:
+ key = attn.to_k(hidden_states)
+ value = attn.to_v(hidden_states)
+ key = attn.head_to_batch_dim(key)
+ value = attn.head_to_batch_dim(value)
+ key = torch.cat([encoder_hidden_states_key_proj, key], dim=1)
+ value = torch.cat([encoder_hidden_states_value_proj, value], dim=1)
+ else:
+ key = encoder_hidden_states_key_proj
+ value = encoder_hidden_states_value_proj
+
+ batch_size_attention, query_tokens, _ = query.shape
+ hidden_states = torch.zeros(
+ (batch_size_attention, query_tokens, dim // attn.heads), device=query.device, dtype=query.dtype
+ )
+
+ for i in range((batch_size_attention - 1) // self.slice_size + 1):
+ start_idx = i * self.slice_size
+ end_idx = (i + 1) * self.slice_size
+
+ query_slice = query[start_idx:end_idx]
+ key_slice = key[start_idx:end_idx]
+ attn_mask_slice = attention_mask[start_idx:end_idx] if attention_mask is not None else None
+
+ attn_slice = attn.get_attention_scores(query_slice, key_slice, attn_mask_slice)
+
+ attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx])
+
+ hidden_states[start_idx:end_idx] = attn_slice
+
+ hidden_states = attn.batch_to_head_dim(hidden_states)
+
+ # linear proj
+ hidden_states = attn.to_out[0](hidden_states)
+ # dropout
+ hidden_states = attn.to_out[1](hidden_states)
+
+ hidden_states = hidden_states.transpose(-1, -2).reshape(residual.shape)
+ hidden_states = hidden_states + residual
+
+ return hidden_states
+
+
+class SpatialNorm(nn.Module):
+ """
+ Spatially conditioned normalization as defined in https://arxiv.org/abs/2209.09002.
+
+ Args:
+ f_channels (`int`):
+ The number of channels for input to group normalization layer, and output of the spatial norm layer.
+ zq_channels (`int`):
+ The number of channels for the quantized vector as described in the paper.
+ """
+
+ def __init__(
+ self,
+ f_channels: int,
+ zq_channels: int,
+ ):
+ super().__init__()
+ self.norm_layer = nn.GroupNorm(num_channels=f_channels, num_groups=32, eps=1e-6, affine=True)
+ self.conv_y = nn.Conv2d(zq_channels, f_channels, kernel_size=1, stride=1, padding=0)
+ self.conv_b = nn.Conv2d(zq_channels, f_channels, kernel_size=1, stride=1, padding=0)
+
+ def forward(self, f: torch.Tensor, zq: torch.Tensor) -> torch.Tensor:
+ f_size = f.shape[-2:]
+ zq = F.interpolate(zq, size=f_size, mode="nearest")
+ norm_f = self.norm_layer(f)
+ new_f = norm_f * self.conv_y(zq) + self.conv_b(zq)
+ return new_f
+
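+
+# Illustrative sketch: `_example_spatial_norm` is a hypothetical helper with made-up shapes.
+# SpatialNorm group-normalizes a feature map `f` and rescales/shifts it per pixel using
+# projections of the quantized latent `zq`, which is first resized to `f`'s spatial size.
+def _example_spatial_norm():
+    norm = SpatialNorm(f_channels=64, zq_channels=4)
+    f = torch.randn(1, 64, 32, 32)
+    zq = torch.randn(1, 4, 8, 8)
+    return norm(f, zq).shape  # torch.Size([1, 64, 32, 32])
+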
+
+class IPAdapterAttnProcessor(nn.Module):
+ r"""
+ Attention processor for Multiple IP-Adapters.
+
+ Args:
+ hidden_size (`int`):
+ The hidden size of the attention layer.
+ cross_attention_dim (`int`):
+ The number of channels in the `encoder_hidden_states`.
+ num_tokens (`int`, `Tuple[int]` or `List[int]`, defaults to `(4,)`):
+ The context length of the image features.
+ scale (`float` or `List[float]`, defaults to 1.0):
+ The weight scale of the image prompt.
+ """
+
+ def __init__(self, hidden_size, cross_attention_dim=None, num_tokens=(4,), scale=1.0):
+ super().__init__()
+
+ self.hidden_size = hidden_size
+ self.cross_attention_dim = cross_attention_dim
+
+ if not isinstance(num_tokens, (tuple, list)):
+ num_tokens = [num_tokens]
+ self.num_tokens = num_tokens
+
+ if not isinstance(scale, list):
+ scale = [scale] * len(num_tokens)
+ if len(scale) != len(num_tokens):
+ raise ValueError("`scale` should be a list of integers with the same length as `num_tokens`.")
+ self.scale = scale
+
+ self.to_k_ip = nn.ModuleList(
+ [nn.Linear(cross_attention_dim, hidden_size, bias=False) for _ in range(len(num_tokens))]
+ )
+ self.to_v_ip = nn.ModuleList(
+ [nn.Linear(cross_attention_dim, hidden_size, bias=False) for _ in range(len(num_tokens))]
+ )
+
+ def __call__(
+ self,
+ attn: Attention,
+ hidden_states: torch.Tensor,
+ encoder_hidden_states: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ temb: Optional[torch.Tensor] = None,
+ scale: float = 1.0,
+ ip_adapter_masks: Optional[torch.Tensor] = None,
+ ):
+ residual = hidden_states
+
+ # separate ip_hidden_states from encoder_hidden_states
+ if encoder_hidden_states is not None:
+ if isinstance(encoder_hidden_states, tuple):
+ encoder_hidden_states, ip_hidden_states = encoder_hidden_states
+ else:
+ deprecation_message = (
+ "You have passed a tensor as `encoder_hidden_states`. This is deprecated and will be removed in a future release."
+ " Please make sure to update your script to pass `encoder_hidden_states` as a tuple to suppress this warning."
+ )
+ deprecate("encoder_hidden_states not a tuple", "1.0.0", deprecation_message, standard_warn=False)
+ end_pos = encoder_hidden_states.shape[1] - self.num_tokens[0]
+ encoder_hidden_states, ip_hidden_states = (
+ encoder_hidden_states[:, :end_pos, :],
+ [encoder_hidden_states[:, end_pos:, :]],
+ )
+
+ if attn.spatial_norm is not None:
+ hidden_states = attn.spatial_norm(hidden_states, temb)
+
+ input_ndim = hidden_states.ndim
+
+ if input_ndim == 4:
+ batch_size, channel, height, width = hidden_states.shape
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
+
+ batch_size, sequence_length, _ = (
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
+ )
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
+
+ if attn.group_norm is not None:
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
+
+ query = attn.to_q(hidden_states)
+
+ if encoder_hidden_states is None:
+ encoder_hidden_states = hidden_states
+ elif attn.norm_cross:
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
+
+ key = attn.to_k(encoder_hidden_states)
+ value = attn.to_v(encoder_hidden_states)
+
+ query = attn.head_to_batch_dim(query)
+ key = attn.head_to_batch_dim(key)
+ value = attn.head_to_batch_dim(value)
+
+ attention_probs = attn.get_attention_scores(query, key, attention_mask)
+ hidden_states = torch.bmm(attention_probs, value)
+ hidden_states = attn.batch_to_head_dim(hidden_states)
+
+ if ip_adapter_masks is not None:
+ if not isinstance(ip_adapter_masks, List):
+ # for backward compatibility, we accept `ip_adapter_mask` as a tensor of shape [num_ip_adapter, 1, height, width]
+ ip_adapter_masks = list(ip_adapter_masks.unsqueeze(1))
+ if not (len(ip_adapter_masks) == len(self.scale) == len(ip_hidden_states)):
+ raise ValueError(
+ f"Length of ip_adapter_masks array ({len(ip_adapter_masks)}) must match "
+ f"length of self.scale array ({len(self.scale)}) and number of ip_hidden_states "
+ f"({len(ip_hidden_states)})"
+ )
+ else:
+ for index, (mask, scale, ip_state) in enumerate(zip(ip_adapter_masks, self.scale, ip_hidden_states)):
+ if not isinstance(mask, torch.Tensor) or mask.ndim != 4:
+ raise ValueError(
+ "Each element of the ip_adapter_masks array should be a tensor with shape "
+ "[1, num_images_for_ip_adapter, height, width]."
+ " Please use `IPAdapterMaskProcessor` to preprocess your mask"
+ )
+ if mask.shape[1] != ip_state.shape[1]:
+ raise ValueError(
+ f"Number of masks ({mask.shape[1]}) does not match "
+ f"number of ip images ({ip_state.shape[1]}) at index {index}"
+ )
+ if isinstance(scale, list) and not len(scale) == mask.shape[1]:
+ raise ValueError(
+ f"Number of masks ({mask.shape[1]}) does not match "
+ f"number of scales ({len(scale)}) at index {index}"
+ )
+ else:
+ ip_adapter_masks = [None] * len(self.scale)
+
+ # for ip-adapter
+ for current_ip_hidden_states, scale, to_k_ip, to_v_ip, mask in zip(
+ ip_hidden_states, self.scale, self.to_k_ip, self.to_v_ip, ip_adapter_masks
+ ):
+ skip = False
+ if isinstance(scale, list):
+ if all(s == 0 for s in scale):
+ skip = True
+ elif scale == 0:
+ skip = True
+ if not skip:
+ if mask is not None:
+ if not isinstance(scale, list):
+ scale = [scale] * mask.shape[1]
+
+ current_num_images = mask.shape[1]
+ for i in range(current_num_images):
+ ip_key = to_k_ip(current_ip_hidden_states[:, i, :, :])
+ ip_value = to_v_ip(current_ip_hidden_states[:, i, :, :])
+
+ ip_key = attn.head_to_batch_dim(ip_key)
+ ip_value = attn.head_to_batch_dim(ip_value)
+
+ ip_attention_probs = attn.get_attention_scores(query, ip_key, None)
+ _current_ip_hidden_states = torch.bmm(ip_attention_probs, ip_value)
+ _current_ip_hidden_states = attn.batch_to_head_dim(_current_ip_hidden_states)
+
+ mask_downsample = IPAdapterMaskProcessor.downsample(
+ mask[:, i, :, :],
+ batch_size,
+ _current_ip_hidden_states.shape[1],
+ _current_ip_hidden_states.shape[2],
+ )
+
+ mask_downsample = mask_downsample.to(dtype=query.dtype, device=query.device)
+
+ hidden_states = hidden_states + scale[i] * (_current_ip_hidden_states * mask_downsample)
+ else:
+ ip_key = to_k_ip(current_ip_hidden_states)
+ ip_value = to_v_ip(current_ip_hidden_states)
+
+ ip_key = attn.head_to_batch_dim(ip_key)
+ ip_value = attn.head_to_batch_dim(ip_value)
+
+ ip_attention_probs = attn.get_attention_scores(query, ip_key, None)
+ current_ip_hidden_states = torch.bmm(ip_attention_probs, ip_value)
+ current_ip_hidden_states = attn.batch_to_head_dim(current_ip_hidden_states)
+
+ hidden_states = hidden_states + scale * current_ip_hidden_states
+
+ # linear proj
+ hidden_states = attn.to_out[0](hidden_states)
+ # dropout
+ hidden_states = attn.to_out[1](hidden_states)
+
+ if input_ndim == 4:
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
+
+ if attn.residual_connection:
+ hidden_states = hidden_states + residual
+
+ hidden_states = hidden_states / attn.rescale_output_factor
+
+ return hidden_states
+
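+
+# Illustrative sketch: `_example_ip_adapter_config` is a hypothetical helper with made-up sizes,
+# showing one processor serving two IP-Adapters, each contributing 4 image tokens, weighted
+# 1.0 and 0.6 respectively via per-adapter `scale` entries.
+def _example_ip_adapter_config():
+    processor = IPAdapterAttnProcessor(
+        hidden_size=1024, cross_attention_dim=2048, num_tokens=(4, 4), scale=[1.0, 0.6]
+    )
+    return len(processor.to_k_ip), processor.scale  # (2, [1.0, 0.6])
+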
+
+class IPAdapterAttnProcessor2_0(torch.nn.Module):
+ r"""
+ Attention processor for IP-Adapter for PyTorch 2.0.
+
+ Args:
+ hidden_size (`int`):
+ The hidden size of the attention layer.
+ cross_attention_dim (`int`):
+ The number of channels in the `encoder_hidden_states`.
+ num_tokens (`int`, `Tuple[int]` or `List[int]`, defaults to `(4,)`):
+ The context length of the image features.
+ scale (`float` or `List[float]`, defaults to 1.0):
+ The weight scale of the image prompt.
+ """
+
+ def __init__(self, hidden_size, cross_attention_dim=None, num_tokens=(4,), scale=1.0):
+ super().__init__()
+
+ if not hasattr(F, "scaled_dot_product_attention"):
+ raise ImportError(
+ f"{self.__class__.__name__} requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
+ )
+
+ self.hidden_size = hidden_size
+ self.cross_attention_dim = cross_attention_dim
+
+ if not isinstance(num_tokens, (tuple, list)):
+ num_tokens = [num_tokens]
+ self.num_tokens = num_tokens
+
+ if not isinstance(scale, list):
+ scale = [scale] * len(num_tokens)
+ if len(scale) != len(num_tokens):
+ raise ValueError("`scale` should be a list of integers with the same length as `num_tokens`.")
+ self.scale = scale
+
+ self.to_k_ip = nn.ModuleList(
+ [nn.Linear(cross_attention_dim, hidden_size, bias=False) for _ in range(len(num_tokens))]
+ )
+ self.to_v_ip = nn.ModuleList(
+ [nn.Linear(cross_attention_dim, hidden_size, bias=False) for _ in range(len(num_tokens))]
+ )
+
+ def __call__(
+ self,
+ attn: Attention,
+ hidden_states: torch.Tensor,
+ encoder_hidden_states: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ temb: Optional[torch.Tensor] = None,
+ scale: float = 1.0,
+ ip_adapter_masks: Optional[torch.Tensor] = None,
+ ):
+ residual = hidden_states
+
+ # separate ip_hidden_states from encoder_hidden_states
+ if encoder_hidden_states is not None:
+ if isinstance(encoder_hidden_states, tuple):
+ encoder_hidden_states, ip_hidden_states = encoder_hidden_states
+ else:
+ deprecation_message = (
+ "You have passed a tensor as `encoder_hidden_states`. This is deprecated and will be removed in a future release."
+ " Please make sure to update your script to pass `encoder_hidden_states` as a tuple to suppress this warning."
+ )
+ deprecate("encoder_hidden_states not a tuple", "1.0.0", deprecation_message, standard_warn=False)
+ end_pos = encoder_hidden_states.shape[1] - self.num_tokens[0]
+ encoder_hidden_states, ip_hidden_states = (
+ encoder_hidden_states[:, :end_pos, :],
+ [encoder_hidden_states[:, end_pos:, :]],
+ )
+
+ if attn.spatial_norm is not None:
+ hidden_states = attn.spatial_norm(hidden_states, temb)
+
+ input_ndim = hidden_states.ndim
+
+ if input_ndim == 4:
+ batch_size, channel, height, width = hidden_states.shape
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
+
+ batch_size, sequence_length, _ = (
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
+ )
+
+ if attention_mask is not None:
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
+ attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
+
+ if attn.group_norm is not None:
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
+
+ query = attn.to_q(hidden_states)
+
+ if encoder_hidden_states is None:
+ encoder_hidden_states = hidden_states
+ elif attn.norm_cross:
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
+
+ key = attn.to_k(encoder_hidden_states)
+ value = attn.to_v(encoder_hidden_states)
+
+ inner_dim = key.shape[-1]
+ head_dim = inner_dim // attn.heads
+
+ query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+ key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+ value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+ hidden_states = F.scaled_dot_product_attention(
+ query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
+ )
+
+ hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
+ hidden_states = hidden_states.to(query.dtype)
+
+ if ip_adapter_masks is not None:
+ if not isinstance(ip_adapter_masks, List):
+ ip_adapter_masks = list(ip_adapter_masks.unsqueeze(1))
+ if not (len(ip_adapter_masks) == len(self.scale) == len(ip_hidden_states)):
+ raise ValueError(
+ f"Length of ip_adapter_masks array ({len(ip_adapter_masks)}) must match "
+ f"length of self.scale array ({len(self.scale)}) and number of ip_hidden_states "
+ f"({len(ip_hidden_states)})"
+ )
+ else:
+ for index, (mask, scale, ip_state) in enumerate(zip(ip_adapter_masks, self.scale, ip_hidden_states)):
+ if not isinstance(mask, torch.Tensor) or mask.ndim != 4:
+ raise ValueError(
+ "Each element of the ip_adapter_masks array should be a tensor with shape "
+ "[1, num_images_for_ip_adapter, height, width]."
+ " Please use `IPAdapterMaskProcessor` to preprocess your mask"
+ )
+ if mask.shape[1] != ip_state.shape[1]:
+ raise ValueError(
+ f"Number of masks ({mask.shape[1]}) does not match "
+ f"number of ip images ({ip_state.shape[1]}) at index {index}"
+ )
+ if isinstance(scale, list) and not len(scale) == mask.shape[1]:
+ raise ValueError(
+ f"Number of masks ({mask.shape[1]}) does not match "
+ f"number of scales ({len(scale)}) at index {index}"
+ )
+ else:
+ ip_adapter_masks = [None] * len(self.scale)
+
+ # for ip-adapter
+ for current_ip_hidden_states, scale, to_k_ip, to_v_ip, mask in zip(
+ ip_hidden_states, self.scale, self.to_k_ip, self.to_v_ip, ip_adapter_masks
+ ):
+ skip = False
+ if isinstance(scale, list):
+ if all(s == 0 for s in scale):
+ skip = True
+ elif scale == 0:
+ skip = True
+ if not skip:
+ if mask is not None:
+ if not isinstance(scale, list):
+ scale = [scale] * mask.shape[1]
+
+ current_num_images = mask.shape[1]
+ for i in range(current_num_images):
+ ip_key = to_k_ip(current_ip_hidden_states[:, i, :, :])
+ ip_value = to_v_ip(current_ip_hidden_states[:, i, :, :])
+
+ ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+ ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+ _current_ip_hidden_states = F.scaled_dot_product_attention(
+ query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False
+ )
+
+ _current_ip_hidden_states = _current_ip_hidden_states.transpose(1, 2).reshape(
+ batch_size, -1, attn.heads * head_dim
+ )
+ _current_ip_hidden_states = _current_ip_hidden_states.to(query.dtype)
+
+ mask_downsample = IPAdapterMaskProcessor.downsample(
+ mask[:, i, :, :],
+ batch_size,
+ _current_ip_hidden_states.shape[1],
+ _current_ip_hidden_states.shape[2],
+ )
+
+ mask_downsample = mask_downsample.to(dtype=query.dtype, device=query.device)
+ hidden_states = hidden_states + scale[i] * (_current_ip_hidden_states * mask_downsample)
+ else:
+ ip_key = to_k_ip(current_ip_hidden_states)
+ ip_value = to_v_ip(current_ip_hidden_states)
+
+ ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+ ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+ current_ip_hidden_states = F.scaled_dot_product_attention(
+ query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False
+ )
+
+ current_ip_hidden_states = current_ip_hidden_states.transpose(1, 2).reshape(
+ batch_size, -1, attn.heads * head_dim
+ )
+ current_ip_hidden_states = current_ip_hidden_states.to(query.dtype)
+
+ hidden_states = hidden_states + scale * current_ip_hidden_states
+
+ # linear proj
+ hidden_states = attn.to_out[0](hidden_states)
+ # dropout
+ hidden_states = attn.to_out[1](hidden_states)
+
+ if input_ndim == 4:
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
+
+ if attn.residual_connection:
+ hidden_states = hidden_states + residual
+
+ hidden_states = hidden_states / attn.rescale_output_factor
+
+ return hidden_states
+
+
+class PAGIdentitySelfAttnProcessor2_0:
+ r"""
+ Processor for implementing Perturbed-Attention Guidance (PAG) using scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
+ PAG reference: https://arxiv.org/abs/2403.17377
+ """
+
+ def __init__(self):
+ if not hasattr(F, "scaled_dot_product_attention"):
+ raise ImportError(
+ "PAGIdentitySelfAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
+ )
+
+ def __call__(
+ self,
+ attn: Attention,
+ hidden_states: torch.FloatTensor,
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ temb: Optional[torch.FloatTensor] = None,
+ ) -> torch.Tensor:
+ residual = hidden_states
+ if attn.spatial_norm is not None:
+ hidden_states = attn.spatial_norm(hidden_states, temb)
+
+ input_ndim = hidden_states.ndim
+ if input_ndim == 4:
+ batch_size, channel, height, width = hidden_states.shape
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
+
+ # chunk
+ hidden_states_org, hidden_states_ptb = hidden_states.chunk(2)
+
+ # original path
+ batch_size, sequence_length, _ = hidden_states_org.shape
+
+ if attention_mask is not None:
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
+ attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
+
+ if attn.group_norm is not None:
+ hidden_states_org = attn.group_norm(hidden_states_org.transpose(1, 2)).transpose(1, 2)
+
+ query = attn.to_q(hidden_states_org)
+ key = attn.to_k(hidden_states_org)
+ value = attn.to_v(hidden_states_org)
+
+ inner_dim = key.shape[-1]
+ head_dim = inner_dim // attn.heads
+
+ query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+ key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+ value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+ hidden_states_org = F.scaled_dot_product_attention(
+ query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
+ )
+ hidden_states_org = hidden_states_org.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
+ hidden_states_org = hidden_states_org.to(query.dtype)
+
+ # linear proj
+ hidden_states_org = attn.to_out[0](hidden_states_org)
+ # dropout
+ hidden_states_org = attn.to_out[1](hidden_states_org)
+
+ if input_ndim == 4:
+ hidden_states_org = hidden_states_org.transpose(-1, -2).reshape(batch_size, channel, height, width)
+
+ # perturbed path (identity attention)
+ batch_size, sequence_length, _ = hidden_states_ptb.shape
+
+ if attn.group_norm is not None:
+ hidden_states_ptb = attn.group_norm(hidden_states_ptb.transpose(1, 2)).transpose(1, 2)
+
+ hidden_states_ptb = attn.to_v(hidden_states_ptb)
+ hidden_states_ptb = hidden_states_ptb.to(query.dtype)
+
+ # linear proj
+ hidden_states_ptb = attn.to_out[0](hidden_states_ptb)
+ # dropout
+ hidden_states_ptb = attn.to_out[1](hidden_states_ptb)
+
+ if input_ndim == 4:
+ hidden_states_ptb = hidden_states_ptb.transpose(-1, -2).reshape(batch_size, channel, height, width)
+
+ # cat
+ hidden_states = torch.cat([hidden_states_org, hidden_states_ptb])
+
+ if attn.residual_connection:
+ hidden_states = hidden_states + residual
+
+ hidden_states = hidden_states / attn.rescale_output_factor
+
+ return hidden_states
+
+
+class PAGCFGIdentitySelfAttnProcessor2_0:
+ r"""
+ Processor for implementing Perturbed-Attention Guidance (PAG) combined with classifier-free guidance, using scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
+ PAG reference: https://arxiv.org/abs/2403.17377
+ """
+
+ def __init__(self):
+ if not hasattr(F, "scaled_dot_product_attention"):
+ raise ImportError(
+ "PAGCFGIdentitySelfAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
+ )
+
+ def __call__(
+ self,
+ attn: Attention,
+ hidden_states: torch.FloatTensor,
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ temb: Optional[torch.FloatTensor] = None,
+ ) -> torch.Tensor:
+ residual = hidden_states
+ if attn.spatial_norm is not None:
+ hidden_states = attn.spatial_norm(hidden_states, temb)
+
+ input_ndim = hidden_states.ndim
+ if input_ndim == 4:
+ batch_size, channel, height, width = hidden_states.shape
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
+
+ # chunk
+ hidden_states_uncond, hidden_states_org, hidden_states_ptb = hidden_states.chunk(3)
+ hidden_states_org = torch.cat([hidden_states_uncond, hidden_states_org])
+
+ # original path
+ batch_size, sequence_length, _ = hidden_states_org.shape
+
+ if attention_mask is not None:
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
+ attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
+
+ if attn.group_norm is not None:
+ hidden_states_org = attn.group_norm(hidden_states_org.transpose(1, 2)).transpose(1, 2)
+
+ query = attn.to_q(hidden_states_org)
+ key = attn.to_k(hidden_states_org)
+ value = attn.to_v(hidden_states_org)
+
+ inner_dim = key.shape[-1]
+ head_dim = inner_dim // attn.heads
+
+ query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+ key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+ value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+ hidden_states_org = F.scaled_dot_product_attention(
+ query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
+ )
+
+ hidden_states_org = hidden_states_org.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
+ hidden_states_org = hidden_states_org.to(query.dtype)
+
+ # linear proj
+ hidden_states_org = attn.to_out[0](hidden_states_org)
+ # dropout
+ hidden_states_org = attn.to_out[1](hidden_states_org)
+
+ if input_ndim == 4:
+ hidden_states_org = hidden_states_org.transpose(-1, -2).reshape(batch_size, channel, height, width)
+
+ # perturbed path (identity attention)
+ batch_size, sequence_length, _ = hidden_states_ptb.shape
+
+ if attn.group_norm is not None:
+ hidden_states_ptb = attn.group_norm(hidden_states_ptb.transpose(1, 2)).transpose(1, 2)
+
+ value = attn.to_v(hidden_states_ptb)
+ hidden_states_ptb = value
+ hidden_states_ptb = hidden_states_ptb.to(query.dtype)
+
+ # linear proj
+ hidden_states_ptb = attn.to_out[0](hidden_states_ptb)
+ # dropout
+ hidden_states_ptb = attn.to_out[1](hidden_states_ptb)
+
+ if input_ndim == 4:
+ hidden_states_ptb = hidden_states_ptb.transpose(-1, -2).reshape(batch_size, channel, height, width)
+
+ # cat
+ hidden_states = torch.cat([hidden_states_org, hidden_states_ptb])
+
+ if attn.residual_connection:
+ hidden_states = hidden_states + residual
+
+ hidden_states = hidden_states / attn.rescale_output_factor
+
+ return hidden_states
+
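+
+# Illustrative sketch: `_example_pag_cfg_chunking` is a hypothetical helper showing the batch
+# layout assumed by PAGCFGIdentitySelfAttnProcessor2_0: three copies per prompt (unconditional,
+# conditional, perturbed); only the perturbed copy takes the identity-attention path.
+def _example_pag_cfg_chunking():
+    hidden_states = torch.randn(3 * 2, 16, 8)  # 2 prompts x 3 guidance branches
+    uncond, cond, ptb = hidden_states.chunk(3)
+    return uncond.shape, cond.shape, ptb.shape  # each torch.Size([2, 16, 8])
+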
+
+class LoRAAttnProcessor:
+ def __init__(self):
+ pass
+
+
+class LoRAAttnProcessor2_0:
+ def __init__(self):
+ pass
+
+
+class LoRAXFormersAttnProcessor:
+ def __init__(self):
+ pass
+
+
+class LoRAAttnAddedKVProcessor:
+ def __init__(self):
+ pass
+
+
+class FluxSingleAttnProcessor2_0(FluxAttnProcessor2_0):
+ r"""
+ Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
+ """
+
+ def __init__(self):
+ deprecation_message = "`FluxSingleAttnProcessor2_0` is deprecated and will be removed in a future version. Please use `FluxAttnProcessor2_0` instead."
+ deprecate("FluxSingleAttnProcessor2_0", "0.32.0", deprecation_message)
+ super().__init__()
+
+
+ADDED_KV_ATTENTION_PROCESSORS = (
+ AttnAddedKVProcessor,
+ SlicedAttnAddedKVProcessor,
+ AttnAddedKVProcessor2_0,
+ XFormersAttnAddedKVProcessor,
+)
+
+CROSS_ATTENTION_PROCESSORS = (
+ AttnProcessor,
+ AttnProcessor2_0,
+ XFormersAttnProcessor,
+ SlicedAttnProcessor,
+ IPAdapterAttnProcessor,
+ IPAdapterAttnProcessor2_0,
+)
+
+AttentionProcessor = Union[
+ AttnProcessor,
+ AttnProcessor2_0,
+ FusedAttnProcessor2_0,
+ XFormersAttnProcessor,
+ SlicedAttnProcessor,
+ AttnAddedKVProcessor,
+ SlicedAttnAddedKVProcessor,
+ AttnAddedKVProcessor2_0,
+ XFormersAttnAddedKVProcessor,
+ CustomDiffusionAttnProcessor,
+ CustomDiffusionXFormersAttnProcessor,
+ CustomDiffusionAttnProcessor2_0,
+ PAGCFGIdentitySelfAttnProcessor2_0,
+ PAGIdentitySelfAttnProcessor2_0,
+ PAGCFGHunyuanAttnProcessor2_0,
+ PAGHunyuanAttnProcessor2_0,
+]
diff --git a/MindIE/MindIE-Torch/built-in/foundation/CogVideoX-5b/cogvideox_5b/models/embeddings.py b/MindIE/MindIE-Torch/built-in/foundation/CogVideoX-5b/cogvideox_5b/models/embeddings.py
new file mode 100644
index 0000000000000000000000000000000000000000..0996d97ecf1f511fd5b7481cfa5f007d3762f570
--- /dev/null
+++ b/MindIE/MindIE-Torch/built-in/foundation/CogVideoX-5b/cogvideox_5b/models/embeddings.py
@@ -0,0 +1,1819 @@
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import math
+from typing import List, Optional, Tuple, Union
+
+import numpy as np
+import torch
+import torch.nn.functional as F
+from torch import nn
+
+from diffusers.utils import deprecate
+from .activations import FP32SiLU, get_activation
+from .attention_processor import Attention
+
+
+def get_timestep_embedding(
+ timesteps: torch.Tensor,
+ embedding_dim: int,
+ flip_sin_to_cos: bool = False,
+ downscale_freq_shift: float = 1,
+ scale: float = 1,
+ max_period: int = 10000,
+):
+ """
+ This matches the implementation in Denoising Diffusion Probabilistic Models: Create sinusoidal timestep embeddings.
+
+ Args:
+ timesteps (torch.Tensor):
+ a 1-D Tensor of N indices, one per batch element. These may be fractional.
+ embedding_dim (int):
+ the dimension of the output.
+ flip_sin_to_cos (bool):
+ Whether the embedding order should be `cos, sin` (if True) or `sin, cos` (if False)
+ downscale_freq_shift (float):
+ Controls the delta between frequencies between dimensions
+ scale (float):
+ Scaling factor applied to the embeddings.
+ max_period (int):
+ Controls the maximum frequency of the embeddings
+ Returns:
+ torch.Tensor: an [N x dim] Tensor of positional embeddings.
+ """
+ assert len(timesteps.shape) == 1, "Timesteps should be a 1d-array"
+
+ half_dim = embedding_dim // 2
+ exponent = -math.log(max_period) * torch.arange(
+ start=0, end=half_dim, dtype=torch.float32, device=timesteps.device
+ )
+ exponent = exponent / (half_dim - downscale_freq_shift)
+
+ emb = torch.exp(exponent)
+ emb = timesteps[:, None].float() * emb[None, :]
+
+ # scale embeddings
+ emb = scale * emb
+
+ # concat sine and cosine embeddings
+ emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=-1)
+
+ # flip sine and cosine embeddings
+ if flip_sin_to_cos:
+ emb = torch.cat([emb[:, half_dim:], emb[:, :half_dim]], dim=-1)
+
+ # zero pad
+ if embedding_dim % 2 == 1:
+ emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
+ return emb
+
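+
+# Illustrative sketch: `_example_timestep_embedding` is a hypothetical helper with arbitrary
+# values (not the CogVideoX defaults), embedding four diffusion timesteps into 256-dim
+# sinusoidal features with the cos-first ordering.
+def _example_timestep_embedding():
+    timesteps = torch.tensor([0.0, 250.0, 500.0, 999.0])
+    emb = get_timestep_embedding(timesteps, embedding_dim=256, flip_sin_to_cos=True, downscale_freq_shift=0)
+    return emb.shape  # torch.Size([4, 256])
+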
+
+def get_3d_sincos_pos_embed(
+ embed_dim: int,
+ spatial_size: Union[int, Tuple[int, int]],
+ temporal_size: int,
+ spatial_interpolation_scale: float = 1.0,
+ temporal_interpolation_scale: float = 1.0,
+) -> np.ndarray:
+ r"""
+ Args:
+ embed_dim (`int`):
+ spatial_size (`int` or `Tuple[int, int]`):
+ temporal_size (`int`):
+ spatial_interpolation_scale (`float`, defaults to 1.0):
+ temporal_interpolation_scale (`float`, defaults to 1.0):
+ """
+ if embed_dim % 4 != 0:
+ raise ValueError("`embed_dim` must be divisible by 4")
+ if isinstance(spatial_size, int):
+ spatial_size = (spatial_size, spatial_size)
+
+ embed_dim_spatial = 3 * embed_dim // 4
+ embed_dim_temporal = embed_dim // 4
+
+ # 1. Spatial
+ grid_h = np.arange(spatial_size[1], dtype=np.float32) / spatial_interpolation_scale
+ grid_w = np.arange(spatial_size[0], dtype=np.float32) / spatial_interpolation_scale
+ grid = np.meshgrid(grid_w, grid_h) # here w goes first
+ grid = np.stack(grid, axis=0)
+
+ grid = grid.reshape([2, 1, spatial_size[1], spatial_size[0]])
+ pos_embed_spatial = get_2d_sincos_pos_embed_from_grid(embed_dim_spatial, grid)
+
+ # 2. Temporal
+ grid_t = np.arange(temporal_size, dtype=np.float32) / temporal_interpolation_scale
+ pos_embed_temporal = get_1d_sincos_pos_embed_from_grid(embed_dim_temporal, grid_t)
+
+ # 3. Concat
+ pos_embed_spatial = pos_embed_spatial[np.newaxis, :, :]
+ pos_embed_spatial = np.repeat(pos_embed_spatial, temporal_size, axis=0) # [T, H*W, D // 4 * 3]
+
+ pos_embed_temporal = pos_embed_temporal[:, np.newaxis, :]
+ pos_embed_temporal = np.repeat(pos_embed_temporal, spatial_size[0] * spatial_size[1], axis=1) # [T, H*W, D // 4]
+
+ pos_embed = np.concatenate([pos_embed_temporal, pos_embed_spatial], axis=-1) # [T, H*W, D]
+ return pos_embed
+
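+
+# Illustrative sketch: `_example_3d_sincos_pos_embed` is a hypothetical helper with made-up
+# sizes. Three quarters of the channels encode the spatial grid and one quarter encodes time,
+# then the two tables are concatenated per (frame, patch) position.
+def _example_3d_sincos_pos_embed():
+    pos = get_3d_sincos_pos_embed(embed_dim=64, spatial_size=(6, 4), temporal_size=3)
+    return pos.shape  # (3, 24, 64): [temporal_size, height * width, embed_dim]
+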
+
+def get_2d_sincos_pos_embed(
+ embed_dim, grid_size, cls_token=False, extra_tokens=0, interpolation_scale=1.0, base_size=16
+):
+ """
+ grid_size: int of the grid height and width.
+ return: pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (with or without cls_token)
+ """
+ if isinstance(grid_size, int):
+ grid_size = (grid_size, grid_size)
+
+ grid_h = np.arange(grid_size[0], dtype=np.float32) / (grid_size[0] / base_size) / interpolation_scale
+ grid_w = np.arange(grid_size[1], dtype=np.float32) / (grid_size[1] / base_size) / interpolation_scale
+ grid = np.meshgrid(grid_w, grid_h) # here w goes first
+ grid = np.stack(grid, axis=0)
+
+ grid = grid.reshape([2, 1, grid_size[1], grid_size[0]])
+ pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
+ if cls_token and extra_tokens > 0:
+ pos_embed = np.concatenate([np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0)
+ return pos_embed
+
+
+def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
+ if embed_dim % 2 != 0:
+ raise ValueError("embed_dim must be divisible by 2")
+
+ # use half of dimensions to encode grid_h
+ emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2)
+ emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2)
+
+ emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D)
+ return emb
+
+
+def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
+ """
+ embed_dim: output dimension for each position.
+ pos: a list of positions to be encoded, of size (M,).
+ out: (M, D)
+ """
+ if embed_dim % 2 != 0:
+ raise ValueError("embed_dim must be divisible by 2")
+
+ omega = np.arange(embed_dim // 2, dtype=np.float64)
+ omega /= embed_dim / 2.0
+ omega = 1.0 / 10000**omega # (D/2,)
+
+ pos = pos.reshape(-1) # (M,)
+ out = np.einsum("m,d->md", pos, omega) # (M, D/2), outer product
+
+ emb_sin = np.sin(out) # (M, D/2)
+ emb_cos = np.cos(out) # (M, D/2)
+
+ emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D)
+ return emb
+
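+
+# Illustrative sketch: `_example_1d_sincos_pos_embed` is a hypothetical helper showing the
+# layout of the 1D table: for each position, the first half of the channels are sines and the
+# second half are cosines over geometrically spaced frequencies.
+def _example_1d_sincos_pos_embed():
+    pos = np.arange(5, dtype=np.float64)
+    emb = get_1d_sincos_pos_embed_from_grid(16, pos)
+    return emb.shape  # (5, 16)
+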
+
+class PatchEmbed(nn.Module):
+ """2D Image to Patch Embedding with support for SD3 cropping."""
+
+ def __init__(
+ self,
+ height=224,
+ width=224,
+ patch_size=16,
+ in_channels=3,
+ embed_dim=768,
+ layer_norm=False,
+ flatten=True,
+ bias=True,
+ interpolation_scale=1,
+ pos_embed_type="sincos",
+ pos_embed_max_size=None, # For SD3 cropping
+ ):
+ super().__init__()
+
+ num_patches = (height // patch_size) * (width // patch_size)
+ self.flatten = flatten
+ self.layer_norm = layer_norm
+ self.pos_embed_max_size = pos_embed_max_size
+
+ self.proj = nn.Conv2d(
+ in_channels, embed_dim, kernel_size=(patch_size, patch_size), stride=patch_size, bias=bias
+ )
+ if layer_norm:
+ self.norm = nn.LayerNorm(embed_dim, elementwise_affine=False, eps=1e-6)
+ else:
+ self.norm = None
+
+ self.patch_size = patch_size
+ self.height, self.width = height // patch_size, width // patch_size
+ self.base_size = height // patch_size
+ self.interpolation_scale = interpolation_scale
+
+ # Calculate positional embeddings based on max size or default
+ if pos_embed_max_size:
+ grid_size = pos_embed_max_size
+ else:
+ grid_size = int(num_patches**0.5)
+
+ if pos_embed_type is None:
+ self.pos_embed = None
+ elif pos_embed_type == "sincos":
+ pos_embed = get_2d_sincos_pos_embed(
+ embed_dim, grid_size, base_size=self.base_size, interpolation_scale=self.interpolation_scale
+ )
+ persistent = True if pos_embed_max_size else False
+ self.register_buffer("pos_embed", torch.from_numpy(pos_embed).float().unsqueeze(0), persistent=persistent)
+ else:
+ raise ValueError(f"Unsupported pos_embed_type: {pos_embed_type}")
+
+ def cropped_pos_embed(self, height, width):
+ """Crops positional embeddings for SD3 compatibility."""
+ if self.pos_embed_max_size is None:
+ raise ValueError("`pos_embed_max_size` must be set for cropping.")
+
+ height = height // self.patch_size
+ width = width // self.patch_size
+ if height > self.pos_embed_max_size:
+ raise ValueError(
+ f"Height ({height}) cannot be greater than `pos_embed_max_size`: {self.pos_embed_max_size}."
+ )
+ if width > self.pos_embed_max_size:
+ raise ValueError(
+ f"Width ({width}) cannot be greater than `pos_embed_max_size`: {self.pos_embed_max_size}."
+ )
+
+ top = (self.pos_embed_max_size - height) // 2
+ left = (self.pos_embed_max_size - width) // 2
+ spatial_pos_embed = self.pos_embed.reshape(1, self.pos_embed_max_size, self.pos_embed_max_size, -1)
+ spatial_pos_embed = spatial_pos_embed[:, top : top + height, left : left + width, :]
+ spatial_pos_embed = spatial_pos_embed.reshape(1, -1, spatial_pos_embed.shape[-1])
+ return spatial_pos_embed
+
+ def forward(self, latent):
+ if self.pos_embed_max_size is not None:
+ height, width = latent.shape[-2:]
+ else:
+ height, width = latent.shape[-2] // self.patch_size, latent.shape[-1] // self.patch_size
+
+ latent = self.proj(latent)
+ if self.flatten:
+ latent = latent.flatten(2).transpose(1, 2) # BCHW -> BNC
+ if self.layer_norm:
+ latent = self.norm(latent)
+ if self.pos_embed is None:
+ return latent.to(latent.dtype)
+ # Interpolate or crop positional embeddings as needed
+ if self.pos_embed_max_size:
+ pos_embed = self.cropped_pos_embed(height, width)
+ else:
+ if self.height != height or self.width != width:
+ pos_embed = get_2d_sincos_pos_embed(
+ embed_dim=self.pos_embed.shape[-1],
+ grid_size=(height, width),
+ base_size=self.base_size,
+ interpolation_scale=self.interpolation_scale,
+ )
+ pos_embed = torch.from_numpy(pos_embed).float().unsqueeze(0).to(latent.device)
+ else:
+ pos_embed = self.pos_embed
+
+ return (latent + pos_embed).to(latent.dtype)
+
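+
+# Illustrative sketch: `_example_patch_embed` is a hypothetical helper with made-up sizes.
+# A 4-channel 64x64 latent with patch_size=2 becomes (64 / 2) * (64 / 2) = 1024 patch tokens,
+# each projected to `embed_dim` and summed with its sin-cos positional embedding.
+def _example_patch_embed():
+    embed = PatchEmbed(height=64, width=64, patch_size=2, in_channels=4, embed_dim=128)
+    latent = torch.randn(1, 4, 64, 64)
+    return embed(latent).shape  # torch.Size([1, 1024, 128])
+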
+
+class LuminaPatchEmbed(nn.Module):
+ """2D Image to Patch Embedding with support for Lumina-T2X"""
+
+ def __init__(self, patch_size=2, in_channels=4, embed_dim=768, bias=True):
+ super().__init__()
+ self.patch_size = patch_size
+ self.proj = nn.Linear(
+ in_features=patch_size * patch_size * in_channels,
+ out_features=embed_dim,
+ bias=bias,
+ )
+
+ def forward(self, x, freqs_cis):
+ """
+ Patchifies and embeds the input tensor(s).
+
+ Args:
+ x (List[torch.Tensor] | torch.Tensor): The input tensor(s) to be patchified and embedded.
+
+ Returns:
+ Tuple[torch.Tensor, torch.Tensor, List[Tuple[int, int]], torch.Tensor]: A tuple containing the patchified
+ and embedded tensor(s), the mask indicating the valid patches, the original image size(s), and the
+ frequency tensor(s).
+ """
+ freqs_cis = freqs_cis.to(x[0].device)
+ patch_height = patch_width = self.patch_size
+ batch_size, channel, height, width = x.size()
+ height_tokens, width_tokens = height // patch_height, width // patch_width
+
+ x = x.view(batch_size, channel, height_tokens, patch_height, width_tokens, patch_width).permute(
+ 0, 2, 4, 1, 3, 5
+ )
+ x = x.flatten(3)
+ x = self.proj(x)
+ x = x.flatten(1, 2)
+
+ mask = torch.ones(x.shape[0], x.shape[1], dtype=torch.int32, device=x.device)
+
+ return (
+ x,
+ mask,
+ [(height, width)] * batch_size,
+ freqs_cis[:height_tokens, :width_tokens].flatten(0, 1).unsqueeze(0),
+ )
+
+
+class CogVideoXPatchEmbed(nn.Module):
+ def __init__(
+ self,
+ patch_size: int = 2,
+ in_channels: int = 16,
+ embed_dim: int = 1920,
+ text_embed_dim: int = 4096,
+ bias: bool = True,
+ sample_width: int = 90,
+ sample_height: int = 60,
+ sample_frames: int = 49,
+ temporal_compression_ratio: int = 4,
+ max_text_seq_length: int = 226,
+ spatial_interpolation_scale: float = 1.875,
+ temporal_interpolation_scale: float = 1.0,
+ use_positional_embeddings: bool = True,
+ use_learned_positional_embeddings: bool = True,
+ ) -> None:
+ super().__init__()
+
+ self.patch_size = patch_size
+ self.embed_dim = embed_dim
+ self.sample_height = sample_height
+ self.sample_width = sample_width
+ self.sample_frames = sample_frames
+ self.temporal_compression_ratio = temporal_compression_ratio
+ self.max_text_seq_length = max_text_seq_length
+ self.spatial_interpolation_scale = spatial_interpolation_scale
+ self.temporal_interpolation_scale = temporal_interpolation_scale
+ self.use_positional_embeddings = use_positional_embeddings
+ self.use_learned_positional_embeddings = use_learned_positional_embeddings
+
+ self.proj = nn.Conv2d(
+ in_channels, embed_dim, kernel_size=(patch_size, patch_size), stride=patch_size, bias=bias
+ )
+ self.text_proj = nn.Linear(text_embed_dim, embed_dim)
+
+ if use_positional_embeddings or use_learned_positional_embeddings:
+ persistent = use_learned_positional_embeddings
+ pos_embedding = self._get_positional_embeddings(sample_height, sample_width, sample_frames)
+ self.register_buffer("pos_embedding", pos_embedding, persistent=persistent)
+
+ def _get_positional_embeddings(self, sample_height: int, sample_width: int, sample_frames: int) -> torch.Tensor:
+ post_patch_height = sample_height // self.patch_size
+ post_patch_width = sample_width // self.patch_size
+ post_time_compression_frames = (sample_frames - 1) // self.temporal_compression_ratio + 1
+ num_patches = post_patch_height * post_patch_width * post_time_compression_frames
+
+ pos_embedding = get_3d_sincos_pos_embed(
+ self.embed_dim,
+ (post_patch_width, post_patch_height),
+ post_time_compression_frames,
+ self.spatial_interpolation_scale,
+ self.temporal_interpolation_scale,
+ )
+ pos_embedding = torch.from_numpy(pos_embedding).flatten(0, 1)
+ joint_pos_embedding = torch.zeros(
+ 1, self.max_text_seq_length + num_patches, self.embed_dim, requires_grad=False
+ )
+ joint_pos_embedding.data[:, self.max_text_seq_length :].copy_(pos_embedding)
+
+ return joint_pos_embedding
+
+ def forward(self, text_embeds: torch.Tensor, image_embeds: torch.Tensor):
+ r"""
+ Args:
+ text_embeds (`torch.Tensor`):
+ Input text embeddings. Expected shape: (batch_size, seq_length, embedding_dim).
+ image_embeds (`torch.Tensor`):
+ Input image embeddings. Expected shape: (batch_size, num_frames, channels, height, width).
+ """
+ text_embeds = self.text_proj(text_embeds)
+
+ batch, num_frames, channels, height, width = image_embeds.shape
+ image_embeds = image_embeds.reshape(-1, channels, height, width)
+ image_embeds = self.proj(image_embeds)
+ image_embeds = image_embeds.view(batch, num_frames, *image_embeds.shape[1:])
+ image_embeds = image_embeds.flatten(3).transpose(2, 3) # [batch, num_frames, height x width, channels]
+ image_embeds = image_embeds.flatten(1, 2) # [batch, num_frames x height x width, channels]
+
+ embeds = torch.cat(
+ [text_embeds, image_embeds], dim=1
+ ).contiguous() # [batch, seq_length + num_frames x height x width, channels]
+
+ if self.use_positional_embeddings or self.use_learned_positional_embeddings:
+ if self.use_learned_positional_embeddings and (self.sample_width != width or self.sample_height != height):
+ raise ValueError(
+ "It is currently not possible to generate videos at a different resolution that the defaults. This should only be the case with 'THUDM/CogVideoX-5b-I2V'."
+ "If you think this is incorrect, please open an issue at https://github.com/huggingface/diffusers/issues."
+ )
+
+ pre_time_compression_frames = (num_frames - 1) * self.temporal_compression_ratio + 1
+
+ if (
+ self.sample_height != height
+ or self.sample_width != width
+ or self.sample_frames != pre_time_compression_frames
+ ):
+ pos_embedding = self._get_positional_embeddings(height, width, pre_time_compression_frames)
+ pos_embedding = pos_embedding.to(embeds.device, dtype=embeds.dtype)
+ else:
+ pos_embedding = self.pos_embedding
+
+ embeds = embeds + pos_embedding
+
+ return embeds
+
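+
+# Illustrative sketch: `_example_cogvideox_sequence_length` is a hypothetical helper that
+# reproduces the sequence-length arithmetic of CogVideoXPatchEmbed for its documented defaults
+# (90x60 latent grid, 49 frames, patch_size 2, 4x temporal compression, 226 text tokens).
+def _example_cogvideox_sequence_length():
+    patch_size, temporal_compression_ratio, max_text_seq_length = 2, 4, 226
+    sample_width, sample_height, sample_frames = 90, 60, 49
+    patches_per_frame = (sample_height // patch_size) * (sample_width // patch_size)  # 1350
+    latent_frames = (sample_frames - 1) // temporal_compression_ratio + 1  # 13
+    return max_text_seq_length + patches_per_frame * latent_frames  # 17776
+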
+
+class CogView3PlusPatchEmbed(nn.Module):
+ def __init__(
+ self,
+ in_channels: int = 16,
+ hidden_size: int = 2560,
+ patch_size: int = 2,
+ text_hidden_size: int = 4096,
+ pos_embed_max_size: int = 128,
+ ):
+ super().__init__()
+ self.in_channels = in_channels
+ self.hidden_size = hidden_size
+ self.patch_size = patch_size
+ self.text_hidden_size = text_hidden_size
+ self.pos_embed_max_size = pos_embed_max_size
+ # Linear projection for image patches
+ self.proj = nn.Linear(in_channels * patch_size**2, hidden_size)
+
+ # Linear projection for text embeddings
+ self.text_proj = nn.Linear(text_hidden_size, hidden_size)
+
+ pos_embed = get_2d_sincos_pos_embed(hidden_size, pos_embed_max_size, base_size=pos_embed_max_size)
+ pos_embed = pos_embed.reshape(pos_embed_max_size, pos_embed_max_size, hidden_size)
+ self.register_buffer("pos_embed", torch.from_numpy(pos_embed).float(), persistent=False)
+
+ def forward(self, hidden_states: torch.Tensor, encoder_hidden_states: torch.Tensor) -> torch.Tensor:
+ batch_size, channel, height, width = hidden_states.shape
+
+ if height % self.patch_size != 0 or width % self.patch_size != 0:
+ raise ValueError("Height and width must be divisible by patch size")
+
+ height = height // self.patch_size
+ width = width // self.patch_size
+ hidden_states = hidden_states.view(batch_size, channel, height, self.patch_size, width, self.patch_size)
+ hidden_states = hidden_states.permute(0, 2, 4, 1, 3, 5).contiguous()
+ hidden_states = hidden_states.view(batch_size, height * width, channel * self.patch_size * self.patch_size)
+
+ # Project the patches
+ hidden_states = self.proj(hidden_states)
+ encoder_hidden_states = self.text_proj(encoder_hidden_states)
+ hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
+
+ # Calculate text_length
+ text_length = encoder_hidden_states.shape[1]
+
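+        # Crop the precomputed 2D sin-cos table to the current latent grid; text tokens receive zero positional embeddings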
+ image_pos_embed = self.pos_embed[:height, :width].reshape(height * width, -1)
+ text_pos_embed = torch.zeros(
+ (text_length, self.hidden_size), dtype=image_pos_embed.dtype, device=image_pos_embed.device
+ )
+ pos_embed = torch.cat([text_pos_embed, image_pos_embed], dim=0)[None, ...]
+
+ return (hidden_states + pos_embed).to(hidden_states.dtype)
+
+
+def get_3d_rotary_pos_embed(
+ embed_dim, crops_coords, grid_size, temporal_size, theta: int = 10000, use_real: bool = True
+) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
+ """
+ RoPE for video tokens with 3D structure.
+
+ Args:
+        embed_dim (`int`):
+ The embedding dimension size, corresponding to hidden_size_head.
+ crops_coords (`Tuple[int]`):
+ The top-left and bottom-right coordinates of the crop.
+ grid_size (`Tuple[int]`):
+ The grid size of the spatial positional embedding (height, width).
+ temporal_size (`int`):
+ The size of the temporal dimension.
+ theta (`float`):
+ Scaling factor for frequency computation.
+
+ Returns:
+        `Tuple[torch.Tensor, torch.Tensor]`: the cosine and sine rotary embeddings, each with shape `(temporal_size * grid_size[0] * grid_size[1], embed_dim)`.
+ """
+ if use_real is not True:
+        raise ValueError("`use_real=False` is not currently supported for get_3d_rotary_pos_embed")
+ start, stop = crops_coords
+ grid_size_h, grid_size_w = grid_size
+ grid_h = np.linspace(start[0], stop[0], grid_size_h, endpoint=False, dtype=np.float32)
+ grid_w = np.linspace(start[1], stop[1], grid_size_w, endpoint=False, dtype=np.float32)
+ grid_t = np.linspace(0, temporal_size, temporal_size, endpoint=False, dtype=np.float32)
+
+ # Compute dimensions for each axis
+ dim_t = embed_dim // 4
+ dim_h = embed_dim // 8 * 3
+ dim_w = embed_dim // 8 * 3
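+    # The per-axis dimensions sum to embed_dim (assuming embed_dim is divisible by 8): dim_t + dim_h + dim_w == embed_dim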
+
+ # Temporal frequencies
+ freqs_t = get_1d_rotary_pos_embed(dim_t, grid_t, use_real=True)
+ # Spatial frequencies for height and width
+ freqs_h = get_1d_rotary_pos_embed(dim_h, grid_h, use_real=True)
+ freqs_w = get_1d_rotary_pos_embed(dim_w, grid_w, use_real=True)
+
+    # Broadcast and concatenate the temporal and spatial frequencies (height and width) into a 3D tensor
+ def combine_time_height_width(freqs_t, freqs_h, freqs_w):
+ freqs_t = freqs_t[:, None, None, :].expand(
+ -1, grid_size_h, grid_size_w, -1
+ ) # temporal_size, grid_size_h, grid_size_w, dim_t
+ freqs_h = freqs_h[None, :, None, :].expand(
+ temporal_size, -1, grid_size_w, -1
+        ) # temporal_size, grid_size_h, grid_size_w, dim_h
+ freqs_w = freqs_w[None, None, :, :].expand(
+ temporal_size, grid_size_h, -1, -1
+        ) # temporal_size, grid_size_h, grid_size_w, dim_w
+
+ freqs = torch.cat(
+ [freqs_t, freqs_h, freqs_w], dim=-1
+ ) # temporal_size, grid_size_h, grid_size_w, (dim_t + dim_h + dim_w)
+ freqs = freqs.view(
+ temporal_size * grid_size_h * grid_size_w, -1
+ ) # (temporal_size * grid_size_h * grid_size_w), (dim_t + dim_h + dim_w)
+ return freqs
+
+    t_cos, t_sin = freqs_t # both t_cos and t_sin have shape: temporal_size, dim_t
+    h_cos, h_sin = freqs_h # both h_cos and h_sin have shape: grid_size_h, dim_h
+    w_cos, w_sin = freqs_w # both w_cos and w_sin have shape: grid_size_w, dim_w
+ cos = combine_time_height_width(t_cos, h_cos, w_cos)
+ sin = combine_time_height_width(t_sin, h_sin, w_sin)
+ return cos, sin
+
+
+def get_2d_rotary_pos_embed(embed_dim, crops_coords, grid_size, use_real=True):
+ """
+ RoPE for image tokens with 2d structure.
+
+ Args:
+        embed_dim (`int`):
+            The embedding dimension size.
+        crops_coords (`Tuple[int]`):
+ The top-left and bottom-right coordinates of the crop.
+ grid_size (`Tuple[int]`):
+ The grid size of the positional embedding.
+ use_real (`bool`):
+ If True, return real part and imaginary part separately. Otherwise, return complex numbers.
+
+ Returns:
+        `torch.Tensor`: positional embedding with shape `(grid_size[0] * grid_size[1], embed_dim/2)`.
+ """
+ start, stop = crops_coords
+ grid_h = np.linspace(start[0], stop[0], grid_size[0], endpoint=False, dtype=np.float32)
+ grid_w = np.linspace(start[1], stop[1], grid_size[1], endpoint=False, dtype=np.float32)
+ grid = np.meshgrid(grid_w, grid_h) # here w goes first
+ grid = np.stack(grid, axis=0) # [2, W, H]
+
+ grid = grid.reshape([2, 1, *grid.shape[1:]])
+ pos_embed = get_2d_rotary_pos_embed_from_grid(embed_dim, grid, use_real=use_real)
+ return pos_embed
+
+
+def get_2d_rotary_pos_embed_from_grid(embed_dim, grid, use_real=False):
+ assert embed_dim % 4 == 0
+
+ # use half of dimensions to encode grid_h
+ emb_h = get_1d_rotary_pos_embed(
+ embed_dim // 2, grid[0].reshape(-1), use_real=use_real
+ ) # (H*W, D/2) if use_real else (H*W, D/4)
+ emb_w = get_1d_rotary_pos_embed(
+ embed_dim // 2, grid[1].reshape(-1), use_real=use_real
+ ) # (H*W, D/2) if use_real else (H*W, D/4)
+
+ if use_real:
+ cos = torch.cat([emb_h[0], emb_w[0]], dim=1) # (H*W, D)
+ sin = torch.cat([emb_h[1], emb_w[1]], dim=1) # (H*W, D)
+ return cos, sin
+ else:
+ emb = torch.cat([emb_h, emb_w], dim=1) # (H*W, D/2)
+ return emb
+
+
+def get_2d_rotary_pos_embed_lumina(embed_dim, len_h, len_w, linear_factor=1.0, ntk_factor=1.0):
+ assert embed_dim % 4 == 0
+
+ emb_h = get_1d_rotary_pos_embed(
+ embed_dim // 2, len_h, linear_factor=linear_factor, ntk_factor=ntk_factor
+ ) # (H, D/4)
+ emb_w = get_1d_rotary_pos_embed(
+ embed_dim // 2, len_w, linear_factor=linear_factor, ntk_factor=ntk_factor
+ ) # (W, D/4)
+ emb_h = emb_h.view(len_h, 1, embed_dim // 4, 1).repeat(1, len_w, 1, 1) # (H, W, D/4, 1)
+ emb_w = emb_w.view(1, len_w, embed_dim // 4, 1).repeat(len_h, 1, 1, 1) # (H, W, D/4, 1)
+
+ emb = torch.cat([emb_h, emb_w], dim=-1).flatten(2) # (H, W, D/2)
+ return emb
+
+
+def get_1d_rotary_pos_embed(
+ dim: int,
+ pos: Union[np.ndarray, int],
+ theta: float = 10000.0,
+ use_real=False,
+ linear_factor=1.0,
+ ntk_factor=1.0,
+ repeat_interleave_real=True,
+ freqs_dtype=torch.float32, # torch.float32, torch.float64 (flux)
+):
+ """
+ Precompute the frequency tensor for complex exponentials (cis) with given dimensions.
+
+ This function calculates a frequency tensor with complex exponentials using the given dimension 'dim' and the end
+ index 'end'. The 'theta' parameter scales the frequencies. The returned tensor contains complex values in complex64
+ data type.
+
+ Args:
+ dim (`int`): Dimension of the frequency tensor.
+ pos (`np.ndarray` or `int`): Position indices for the frequency tensor. [S] or scalar
+ theta (`float`, *optional*, defaults to 10000.0):
+ Scaling factor for frequency computation. Defaults to 10000.0.
+ use_real (`bool`, *optional*):
+ If True, return real part and imaginary part separately. Otherwise, return complex numbers.
+ linear_factor (`float`, *optional*, defaults to 1.0):
+ Scaling factor for the context extrapolation. Defaults to 1.0.
+ ntk_factor (`float`, *optional*, defaults to 1.0):
+ Scaling factor for the NTK-Aware RoPE. Defaults to 1.0.
+ repeat_interleave_real (`bool`, *optional*, defaults to `True`):
+ If `True` and `use_real`, real part and imaginary part are each interleaved with themselves to reach `dim`.
+            Otherwise, they are concatenated with themselves.
+ freqs_dtype (`torch.float32` or `torch.float64`, *optional*, defaults to `torch.float32`):
+ the dtype of the frequency tensor.
+ Returns:
+ `torch.Tensor`: Precomputed frequency tensor with complex exponentials. [S, D/2]
+ """
+ assert dim % 2 == 0
+
+ if isinstance(pos, int):
+ pos = torch.arange(pos)
+ if isinstance(pos, np.ndarray):
+ pos = torch.from_numpy(pos) # type: ignore # [S]
+
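+    # NTK-aware RoPE rescales the base theta; the resulting frequencies are additionally divided by linear_factor for context extrapolation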
+ theta = theta * ntk_factor
+ freqs = (
+ 1.0
+ / (theta ** (torch.arange(0, dim, 2, dtype=freqs_dtype, device=pos.device)[: (dim // 2)] / dim))
+ / linear_factor
+ ) # [D/2]
+ freqs = torch.outer(pos, freqs) # type: ignore # [S, D/2]
+ if use_real and repeat_interleave_real:
+ # flux, hunyuan-dit, cogvideox
+ freqs_cos = freqs.cos().repeat_interleave(2, dim=1).float() # [S, D]
+ freqs_sin = freqs.sin().repeat_interleave(2, dim=1).float() # [S, D]
+ return freqs_cos, freqs_sin
+ elif use_real:
+ # stable audio
+ freqs_cos = torch.cat([freqs.cos(), freqs.cos()], dim=-1).float() # [S, D]
+ freqs_sin = torch.cat([freqs.sin(), freqs.sin()], dim=-1).float() # [S, D]
+ return freqs_cos, freqs_sin
+ else:
+ # lumina
+ freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64 # [S, D/2]
+ return freqs_cis
+
+
+def apply_rotary_emb(
+ x: torch.Tensor,
+ freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]],
+ use_real: bool = True,
+ use_real_unbind_dim: int = -1,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+ """
+ Apply rotary embeddings to input tensors using the given frequency tensor. This function applies rotary embeddings
+ to the given query or key 'x' tensors using the provided frequency tensor 'freqs_cis'. The input tensors are
+ reshaped as complex numbers, and the frequency tensor is reshaped for broadcasting compatibility. The resulting
+ tensors contain rotary embeddings and are returned as real tensors.
+
+ Args:
+ x (`torch.Tensor`):
+            Query or key tensor to apply rotary embeddings to. Expected shape: [B, H, S, D].
+ freqs_cis (`Tuple[torch.Tensor]`): Precomputed frequency tensor for complex exponentials. ([S, D], [S, D],)
+
+ Returns:
+ Tuple[torch.Tensor, torch.Tensor]: Tuple of modified query tensor and key tensor with rotary embeddings.
+ """
+ if use_real:
+ cos, sin = freqs_cis # [S, D]
+ cos = cos[None, None]
+ sin = sin[None, None]
+ cos, sin = cos.to(x.device), sin.to(x.device)
+
+ if use_real_unbind_dim == -1:
+ # Used for flux, cogvideox, hunyuan-dit
+ x_real, x_imag = x.reshape(*x.shape[:-1], -1, 2).unbind(-1) # [B, S, H, D//2]
+ x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(3)
+ elif use_real_unbind_dim == -2:
+ # Used for Stable Audio
+ x_real, x_imag = x.reshape(*x.shape[:-1], 2, -1).unbind(-2) # [B, S, H, D//2]
+ x_rotated = torch.cat([-x_imag, x_real], dim=-1)
+ else:
+ raise ValueError(f"`use_real_unbind_dim={use_real_unbind_dim}` but should be -1 or -2.")
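+        # Real-valued RoPE: out = x * cos + rotate_half(x) * sin, equivalent to rotating each channel pair by e^{i*theta}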
+ out = (x.float() * cos + x_rotated.float() * sin).to(x.dtype)
+
+ return out
+ else:
+ # used for lumina
+ x_rotated = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))
+ freqs_cis = freqs_cis.unsqueeze(2)
+ x_out = torch.view_as_real(x_rotated * freqs_cis).flatten(3)
+
+ return x_out.type_as(x)
+
+
+class FluxPosEmbed(nn.Module):
+ # modified from https://github.com/black-forest-labs/flux/blob/c00d7c60b085fce8058b9df845e036090873f2ce/src/flux/modules/layers.py#L11
+ def __init__(self, theta: int, axes_dim: List[int]):
+ super().__init__()
+ self.theta = theta
+ self.axes_dim = axes_dim
+
+ def forward(self, ids: torch.Tensor) -> torch.Tensor:
+ n_axes = ids.shape[-1]
+ cos_out = []
+ sin_out = []
+ pos = ids.float()
+ is_mps = ids.device.type == "mps"
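+        # float64 frequencies are more precise, but the MPS backend does not support float64, so fall back to float32 there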
+ freqs_dtype = torch.float32 if is_mps else torch.float64
+ for i in range(n_axes):
+ cos, sin = get_1d_rotary_pos_embed(
+ self.axes_dim[i], pos[:, i], repeat_interleave_real=True, use_real=True, freqs_dtype=freqs_dtype
+ )
+ cos_out.append(cos)
+ sin_out.append(sin)
+ freqs_cos = torch.cat(cos_out, dim=-1).to(ids.device)
+ freqs_sin = torch.cat(sin_out, dim=-1).to(ids.device)
+ return freqs_cos, freqs_sin
+
+
+class TimestepEmbedding(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ time_embed_dim: int,
+ act_fn: str = "silu",
+ out_dim: int = None,
+ post_act_fn: Optional[str] = None,
+ cond_proj_dim=None,
+ sample_proj_bias=True,
+ ):
+ super().__init__()
+
+ self.linear_1 = nn.Linear(in_channels, time_embed_dim, sample_proj_bias)
+
+ if cond_proj_dim is not None:
+ self.cond_proj = nn.Linear(cond_proj_dim, in_channels, bias=False)
+ else:
+ self.cond_proj = None
+
+ self.act = get_activation(act_fn)
+
+ if out_dim is not None:
+ time_embed_dim_out = out_dim
+ else:
+ time_embed_dim_out = time_embed_dim
+ self.linear_2 = nn.Linear(time_embed_dim, time_embed_dim_out, sample_proj_bias)
+
+ if post_act_fn is None:
+ self.post_act = None
+ else:
+ self.post_act = get_activation(post_act_fn)
+
+ def forward(self, sample, condition=None):
+ if condition is not None:
+ sample = sample + self.cond_proj(condition)
+ sample = self.linear_1(sample)
+
+ if self.act is not None:
+ sample = self.act(sample)
+
+ sample = self.linear_2(sample)
+
+ if self.post_act is not None:
+ sample = self.post_act(sample)
+ return sample
+
+
+class Timesteps(nn.Module):
+ def __init__(self, num_channels: int, flip_sin_to_cos: bool, downscale_freq_shift: float, scale: int = 1):
+ super().__init__()
+ self.num_channels = num_channels
+ self.flip_sin_to_cos = flip_sin_to_cos
+ self.downscale_freq_shift = downscale_freq_shift
+ self.scale = scale
+
+ def forward(self, timesteps):
+ t_emb = get_timestep_embedding(
+ timesteps,
+ self.num_channels,
+ flip_sin_to_cos=self.flip_sin_to_cos,
+ downscale_freq_shift=self.downscale_freq_shift,
+ scale=self.scale,
+ )
+ return t_emb
+
+
+class GaussianFourierProjection(nn.Module):
+ """Gaussian Fourier embeddings for noise levels."""
+
+ def __init__(
+ self, embedding_size: int = 256, scale: float = 1.0, set_W_to_weight=True, log=True, flip_sin_to_cos=False
+ ):
+ super().__init__()
+ self.weight = nn.Parameter(torch.randn(embedding_size) * scale, requires_grad=False)
+ self.log = log
+ self.flip_sin_to_cos = flip_sin_to_cos
+
+ if set_W_to_weight:
+ # to delete later
+ del self.weight
+ self.W = nn.Parameter(torch.randn(embedding_size) * scale, requires_grad=False)
+ self.weight = self.W
+ del self.W
+
+ def forward(self, x):
+ if self.log:
+ x = torch.log(x)
+
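+        # Random Fourier features: project the (log) noise levels onto fixed Gaussian frequencies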
+ x_proj = x[:, None] * self.weight[None, :] * 2 * np.pi
+
+ if self.flip_sin_to_cos:
+ out = torch.cat([torch.cos(x_proj), torch.sin(x_proj)], dim=-1)
+ else:
+ out = torch.cat([torch.sin(x_proj), torch.cos(x_proj)], dim=-1)
+ return out
+
+
+class SinusoidalPositionalEmbedding(nn.Module):
+ """Apply positional information to a sequence of embeddings.
+
+ Takes in a sequence of embeddings with shape (batch_size, seq_length, embed_dim) and adds positional embeddings to
+ them
+
+ Args:
+        embed_dim (`int`): Dimension of the positional embedding.
+        max_seq_length (`int`): Maximum sequence length to apply positional embeddings to.
+
+ """
+
+ def __init__(self, embed_dim: int, max_seq_length: int = 32):
+ super().__init__()
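+        # Standard sinusoidal table from "Attention Is All You Need": sine on even channels, cosine on odd channels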
+ position = torch.arange(max_seq_length).unsqueeze(1)
+ div_term = torch.exp(torch.arange(0, embed_dim, 2) * (-math.log(10000.0) / embed_dim))
+ pe = torch.zeros(1, max_seq_length, embed_dim)
+ pe[0, :, 0::2] = torch.sin(position * div_term)
+ pe[0, :, 1::2] = torch.cos(position * div_term)
+ self.register_buffer("pe", pe)
+
+ def forward(self, x):
+ _, seq_length, _ = x.shape
+ x = x + self.pe[:, :seq_length]
+ return x
+
+
+class ImagePositionalEmbeddings(nn.Module):
+ """
+ Converts latent image classes into vector embeddings. Sums the vector embeddings with positional embeddings for the
+ height and width of the latent space.
+
+ For more details, see figure 10 of the dall-e paper: https://arxiv.org/abs/2102.12092
+
+ For VQ-diffusion:
+
+ Output vector embeddings are used as input for the transformer.
+
+ Note that the vector embeddings for the transformer are different than the vector embeddings from the VQVAE.
+
+ Args:
+ num_embed (`int`):
+ Number of embeddings for the latent pixels embeddings.
+ height (`int`):
+ Height of the latent image i.e. the number of height embeddings.
+ width (`int`):
+ Width of the latent image i.e. the number of width embeddings.
+ embed_dim (`int`):
+ Dimension of the produced vector embeddings. Used for the latent pixel, height, and width embeddings.
+ """
+
+ def __init__(
+ self,
+ num_embed: int,
+ height: int,
+ width: int,
+ embed_dim: int,
+ ):
+ super().__init__()
+
+ self.height = height
+ self.width = width
+ self.num_embed = num_embed
+ self.embed_dim = embed_dim
+
+ self.emb = nn.Embedding(self.num_embed, embed_dim)
+ self.height_emb = nn.Embedding(self.height, embed_dim)
+ self.width_emb = nn.Embedding(self.width, embed_dim)
+
+ def forward(self, index):
+ emb = self.emb(index)
+
+ height_emb = self.height_emb(torch.arange(self.height, device=index.device).view(1, self.height))
+
+ # 1 x H x D -> 1 x H x 1 x D
+ height_emb = height_emb.unsqueeze(2)
+
+ width_emb = self.width_emb(torch.arange(self.width, device=index.device).view(1, self.width))
+
+ # 1 x W x D -> 1 x 1 x W x D
+ width_emb = width_emb.unsqueeze(1)
+
+ pos_emb = height_emb + width_emb
+
+        # 1 x H x W x D -> 1 x L x D
+ pos_emb = pos_emb.view(1, self.height * self.width, -1)
+
+ emb = emb + pos_emb[:, : emb.shape[1], :]
+
+ return emb
+
+
+class LabelEmbedding(nn.Module):
+ """
+ Embeds class labels into vector representations. Also handles label dropout for classifier-free guidance.
+
+ Args:
+ num_classes (`int`): The number of classes.
+ hidden_size (`int`): The size of the vector embeddings.
+ dropout_prob (`float`): The probability of dropping a label.
+ """
+
+ def __init__(self, num_classes, hidden_size, dropout_prob):
+ super().__init__()
+ use_cfg_embedding = dropout_prob > 0
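+        # When dropout is enabled, reserve one extra embedding row that serves as the "null" class for classifier-free guidance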
+ self.embedding_table = nn.Embedding(num_classes + use_cfg_embedding, hidden_size)
+ self.num_classes = num_classes
+ self.dropout_prob = dropout_prob
+
+ def token_drop(self, labels, force_drop_ids=None):
+ """
+ Drops labels to enable classifier-free guidance.
+ """
+ if force_drop_ids is None:
+ drop_ids = torch.rand(labels.shape[0], device=labels.device) < self.dropout_prob
+ else:
+ drop_ids = torch.tensor(force_drop_ids == 1)
+ labels = torch.where(drop_ids, self.num_classes, labels)
+ return labels
+
+ def forward(self, labels: torch.LongTensor, force_drop_ids=None):
+ use_dropout = self.dropout_prob > 0
+ if (self.training and use_dropout) or (force_drop_ids is not None):
+ labels = self.token_drop(labels, force_drop_ids)
+ embeddings = self.embedding_table(labels)
+ return embeddings
+
+
+class TextImageProjection(nn.Module):
+ def __init__(
+ self,
+ text_embed_dim: int = 1024,
+ image_embed_dim: int = 768,
+ cross_attention_dim: int = 768,
+ num_image_text_embeds: int = 10,
+ ):
+ super().__init__()
+
+ self.num_image_text_embeds = num_image_text_embeds
+ self.image_embeds = nn.Linear(image_embed_dim, self.num_image_text_embeds * cross_attention_dim)
+ self.text_proj = nn.Linear(text_embed_dim, cross_attention_dim)
+
+ def forward(self, text_embeds: torch.Tensor, image_embeds: torch.Tensor):
+ batch_size = text_embeds.shape[0]
+
+ # image
+ image_text_embeds = self.image_embeds(image_embeds)
+ image_text_embeds = image_text_embeds.reshape(batch_size, self.num_image_text_embeds, -1)
+
+ # text
+ text_embeds = self.text_proj(text_embeds)
+
+ return torch.cat([image_text_embeds, text_embeds], dim=1)
+
+
+class ImageProjection(nn.Module):
+ def __init__(
+ self,
+ image_embed_dim: int = 768,
+ cross_attention_dim: int = 768,
+ num_image_text_embeds: int = 32,
+ ):
+ super().__init__()
+
+ self.num_image_text_embeds = num_image_text_embeds
+ self.image_embeds = nn.Linear(image_embed_dim, self.num_image_text_embeds * cross_attention_dim)
+ self.norm = nn.LayerNorm(cross_attention_dim)
+
+ def forward(self, image_embeds: torch.Tensor):
+ batch_size = image_embeds.shape[0]
+
+ # image
+ image_embeds = self.image_embeds(image_embeds)
+ image_embeds = image_embeds.reshape(batch_size, self.num_image_text_embeds, -1)
+ image_embeds = self.norm(image_embeds)
+ return image_embeds
+
+
+class IPAdapterFullImageProjection(nn.Module):
+ def __init__(self, image_embed_dim=1024, cross_attention_dim=1024):
+ super().__init__()
+ from .attention import FeedForward
+
+ self.ff = FeedForward(image_embed_dim, cross_attention_dim, mult=1, activation_fn="gelu")
+ self.norm = nn.LayerNorm(cross_attention_dim)
+
+ def forward(self, image_embeds: torch.Tensor):
+ return self.norm(self.ff(image_embeds))
+
+
+class IPAdapterFaceIDImageProjection(nn.Module):
+ def __init__(self, image_embed_dim=1024, cross_attention_dim=1024, mult=1, num_tokens=1):
+ super().__init__()
+ from .attention import FeedForward
+
+ self.num_tokens = num_tokens
+ self.cross_attention_dim = cross_attention_dim
+ self.ff = FeedForward(image_embed_dim, cross_attention_dim * num_tokens, mult=mult, activation_fn="gelu")
+ self.norm = nn.LayerNorm(cross_attention_dim)
+
+ def forward(self, image_embeds: torch.Tensor):
+ x = self.ff(image_embeds)
+ x = x.reshape(-1, self.num_tokens, self.cross_attention_dim)
+ return self.norm(x)
+
+
+class CombinedTimestepLabelEmbeddings(nn.Module):
+ def __init__(self, num_classes, embedding_dim, class_dropout_prob=0.1):
+ super().__init__()
+
+ self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=1)
+ self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
+ self.class_embedder = LabelEmbedding(num_classes, embedding_dim, class_dropout_prob)
+
+ def forward(self, timestep, class_labels, hidden_dtype=None):
+ timesteps_proj = self.time_proj(timestep)
+ timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=hidden_dtype)) # (N, D)
+
+ class_labels = self.class_embedder(class_labels) # (N, D)
+
+ conditioning = timesteps_emb + class_labels # (N, D)
+
+ return conditioning
+
+
+class CombinedTimestepTextProjEmbeddings(nn.Module):
+ def __init__(self, embedding_dim, pooled_projection_dim):
+ super().__init__()
+
+ self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
+ self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
+ self.text_embedder = PixArtAlphaTextProjection(pooled_projection_dim, embedding_dim, act_fn="silu")
+
+ def forward(self, timestep, pooled_projection):
+ timesteps_proj = self.time_proj(timestep)
+ timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=pooled_projection.dtype)) # (N, D)
+
+ pooled_projections = self.text_embedder(pooled_projection)
+
+ conditioning = timesteps_emb + pooled_projections
+
+ return conditioning
+
+
+class CombinedTimestepGuidanceTextProjEmbeddings(nn.Module):
+ def __init__(self, embedding_dim, pooled_projection_dim):
+ super().__init__()
+
+ self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
+ self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
+ self.guidance_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
+ self.text_embedder = PixArtAlphaTextProjection(pooled_projection_dim, embedding_dim, act_fn="silu")
+
+ def forward(self, timestep, guidance, pooled_projection):
+ timesteps_proj = self.time_proj(timestep)
+ timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=pooled_projection.dtype)) # (N, D)
+
+ guidance_proj = self.time_proj(guidance)
+ guidance_emb = self.guidance_embedder(guidance_proj.to(dtype=pooled_projection.dtype)) # (N, D)
+
+ time_guidance_emb = timesteps_emb + guidance_emb
+
+ pooled_projections = self.text_embedder(pooled_projection)
+ conditioning = time_guidance_emb + pooled_projections
+
+ return conditioning
+
+
+class CogView3CombinedTimestepSizeEmbeddings(nn.Module):
+ def __init__(self, embedding_dim: int, condition_dim: int, pooled_projection_dim: int, timesteps_dim: int = 256):
+ super().__init__()
+
+ self.time_proj = Timesteps(num_channels=timesteps_dim, flip_sin_to_cos=True, downscale_freq_shift=0)
+ self.condition_proj = Timesteps(num_channels=condition_dim, flip_sin_to_cos=True, downscale_freq_shift=0)
+ self.timestep_embedder = TimestepEmbedding(in_channels=timesteps_dim, time_embed_dim=embedding_dim)
+ self.condition_embedder = PixArtAlphaTextProjection(pooled_projection_dim, embedding_dim, act_fn="silu")
+
+ def forward(
+ self,
+ timestep: torch.Tensor,
+ original_size: torch.Tensor,
+ target_size: torch.Tensor,
+ crop_coords: torch.Tensor,
+ hidden_dtype: torch.dtype,
+ ) -> torch.Tensor:
+ timesteps_proj = self.time_proj(timestep)
+
+ original_size_proj = self.condition_proj(original_size.flatten()).view(original_size.size(0), -1)
+ crop_coords_proj = self.condition_proj(crop_coords.flatten()).view(crop_coords.size(0), -1)
+ target_size_proj = self.condition_proj(target_size.flatten()).view(target_size.size(0), -1)
+
+ # (B, 3 * condition_dim)
+ condition_proj = torch.cat([original_size_proj, crop_coords_proj, target_size_proj], dim=1)
+
+ timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=hidden_dtype)) # (B, embedding_dim)
+ condition_emb = self.condition_embedder(condition_proj.to(dtype=hidden_dtype)) # (B, embedding_dim)
+
+ conditioning = timesteps_emb + condition_emb
+ return conditioning
+
+
+class HunyuanDiTAttentionPool(nn.Module):
+ # Copied from https://github.com/Tencent/HunyuanDiT/blob/cb709308d92e6c7e8d59d0dff41b74d35088db6a/hydit/modules/poolers.py#L6
+
+ def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None):
+ super().__init__()
+ self.positional_embedding = nn.Parameter(torch.randn(spacial_dim + 1, embed_dim) / embed_dim**0.5)
+ self.k_proj = nn.Linear(embed_dim, embed_dim)
+ self.q_proj = nn.Linear(embed_dim, embed_dim)
+ self.v_proj = nn.Linear(embed_dim, embed_dim)
+ self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim)
+ self.num_heads = num_heads
+
+ def forward(self, x):
+ x = x.permute(1, 0, 2) # NLC -> LNC
+ x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (L+1)NC
+ x = x + self.positional_embedding[:, None, :].to(x.dtype) # (L+1)NC
+ x, _ = F.multi_head_attention_forward(
+ query=x[:1],
+ key=x,
+ value=x,
+ embed_dim_to_check=x.shape[-1],
+ num_heads=self.num_heads,
+ q_proj_weight=self.q_proj.weight,
+ k_proj_weight=self.k_proj.weight,
+ v_proj_weight=self.v_proj.weight,
+ in_proj_weight=None,
+ in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]),
+ bias_k=None,
+ bias_v=None,
+ add_zero_attn=False,
+ dropout_p=0,
+ out_proj_weight=self.c_proj.weight,
+ out_proj_bias=self.c_proj.bias,
+ use_separate_proj_weight=True,
+ training=self.training,
+ need_weights=False,
+ )
+ return x.squeeze(0)
+
+
+class HunyuanCombinedTimestepTextSizeStyleEmbedding(nn.Module):
+ def __init__(
+ self,
+ embedding_dim,
+ pooled_projection_dim=1024,
+ seq_len=256,
+ cross_attention_dim=2048,
+ use_style_cond_and_image_meta_size=True,
+ ):
+ super().__init__()
+
+ self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
+ self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
+
+ self.size_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
+
+ self.pooler = HunyuanDiTAttentionPool(
+ seq_len, cross_attention_dim, num_heads=8, output_dim=pooled_projection_dim
+ )
+
+ # Here we use a default learned embedder layer for future extension.
+ self.use_style_cond_and_image_meta_size = use_style_cond_and_image_meta_size
+ if use_style_cond_and_image_meta_size:
+ self.style_embedder = nn.Embedding(1, embedding_dim)
+ extra_in_dim = 256 * 6 + embedding_dim + pooled_projection_dim
+ else:
+ extra_in_dim = pooled_projection_dim
+
+ self.extra_embedder = PixArtAlphaTextProjection(
+ in_features=extra_in_dim,
+ hidden_size=embedding_dim * 4,
+ out_features=embedding_dim,
+ act_fn="silu_fp32",
+ )
+
+ def forward(self, timestep, encoder_hidden_states, image_meta_size, style, hidden_dtype=None):
+ timesteps_proj = self.time_proj(timestep)
+ timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=hidden_dtype)) # (N, 256)
+
+ # extra condition1: text
+ pooled_projections = self.pooler(encoder_hidden_states) # (N, 1024)
+
+ if self.use_style_cond_and_image_meta_size:
+ # extra condition2: image meta size embedding
+ image_meta_size = self.size_proj(image_meta_size.view(-1))
+ image_meta_size = image_meta_size.to(dtype=hidden_dtype)
+ image_meta_size = image_meta_size.view(-1, 6 * 256) # (N, 1536)
+
+ # extra condition3: style embedding
+ style_embedding = self.style_embedder(style) # (N, embedding_dim)
+
+ # Concatenate all extra vectors
+ extra_cond = torch.cat([pooled_projections, image_meta_size, style_embedding], dim=1)
+ else:
+ extra_cond = torch.cat([pooled_projections], dim=1)
+
+ conditioning = timesteps_emb + self.extra_embedder(extra_cond) # [B, D]
+
+ return conditioning
+
+
+class LuminaCombinedTimestepCaptionEmbedding(nn.Module):
+ def __init__(self, hidden_size=4096, cross_attention_dim=2048, frequency_embedding_size=256):
+ super().__init__()
+ self.time_proj = Timesteps(
+ num_channels=frequency_embedding_size, flip_sin_to_cos=True, downscale_freq_shift=0.0
+ )
+
+ self.timestep_embedder = TimestepEmbedding(in_channels=frequency_embedding_size, time_embed_dim=hidden_size)
+
+ self.caption_embedder = nn.Sequential(
+ nn.LayerNorm(cross_attention_dim),
+ nn.Linear(
+ cross_attention_dim,
+ hidden_size,
+ bias=True,
+ ),
+ )
+
+ def forward(self, timestep, caption_feat, caption_mask):
+ # timestep embedding:
+ time_freq = self.time_proj(timestep)
+ time_embed = self.timestep_embedder(time_freq.to(dtype=self.timestep_embedder.linear_1.weight.dtype))
+
+ # caption condition embedding:
+ caption_mask_float = caption_mask.float().unsqueeze(-1)
+ caption_feats_pool = (caption_feat * caption_mask_float).sum(dim=1) / caption_mask_float.sum(dim=1)
+ caption_feats_pool = caption_feats_pool.to(caption_feat)
+ caption_embed = self.caption_embedder(caption_feats_pool)
+
+ conditioning = time_embed + caption_embed
+
+ return conditioning
+
+
+class TextTimeEmbedding(nn.Module):
+ def __init__(self, encoder_dim: int, time_embed_dim: int, num_heads: int = 64):
+ super().__init__()
+ self.norm1 = nn.LayerNorm(encoder_dim)
+ self.pool = AttentionPooling(num_heads, encoder_dim)
+ self.proj = nn.Linear(encoder_dim, time_embed_dim)
+ self.norm2 = nn.LayerNorm(time_embed_dim)
+
+ def forward(self, hidden_states):
+ hidden_states = self.norm1(hidden_states)
+ hidden_states = self.pool(hidden_states)
+ hidden_states = self.proj(hidden_states)
+ hidden_states = self.norm2(hidden_states)
+ return hidden_states
+
+
+class TextImageTimeEmbedding(nn.Module):
+ def __init__(self, text_embed_dim: int = 768, image_embed_dim: int = 768, time_embed_dim: int = 1536):
+ super().__init__()
+ self.text_proj = nn.Linear(text_embed_dim, time_embed_dim)
+ self.text_norm = nn.LayerNorm(time_embed_dim)
+ self.image_proj = nn.Linear(image_embed_dim, time_embed_dim)
+
+ def forward(self, text_embeds: torch.Tensor, image_embeds: torch.Tensor):
+ # text
+ time_text_embeds = self.text_proj(text_embeds)
+ time_text_embeds = self.text_norm(time_text_embeds)
+
+ # image
+ time_image_embeds = self.image_proj(image_embeds)
+
+ return time_image_embeds + time_text_embeds
+
+
+class ImageTimeEmbedding(nn.Module):
+ def __init__(self, image_embed_dim: int = 768, time_embed_dim: int = 1536):
+ super().__init__()
+ self.image_proj = nn.Linear(image_embed_dim, time_embed_dim)
+ self.image_norm = nn.LayerNorm(time_embed_dim)
+
+ def forward(self, image_embeds: torch.Tensor):
+ # image
+ time_image_embeds = self.image_proj(image_embeds)
+ time_image_embeds = self.image_norm(time_image_embeds)
+ return time_image_embeds
+
+
+class ImageHintTimeEmbedding(nn.Module):
+ def __init__(self, image_embed_dim: int = 768, time_embed_dim: int = 1536):
+ super().__init__()
+ self.image_proj = nn.Linear(image_embed_dim, time_embed_dim)
+ self.image_norm = nn.LayerNorm(time_embed_dim)
+ self.input_hint_block = nn.Sequential(
+ nn.Conv2d(3, 16, 3, padding=1),
+ nn.SiLU(),
+ nn.Conv2d(16, 16, 3, padding=1),
+ nn.SiLU(),
+ nn.Conv2d(16, 32, 3, padding=1, stride=2),
+ nn.SiLU(),
+ nn.Conv2d(32, 32, 3, padding=1),
+ nn.SiLU(),
+ nn.Conv2d(32, 96, 3, padding=1, stride=2),
+ nn.SiLU(),
+ nn.Conv2d(96, 96, 3, padding=1),
+ nn.SiLU(),
+ nn.Conv2d(96, 256, 3, padding=1, stride=2),
+ nn.SiLU(),
+ nn.Conv2d(256, 4, 3, padding=1),
+ )
+
+ def forward(self, image_embeds: torch.Tensor, hint: torch.Tensor):
+ # image
+ time_image_embeds = self.image_proj(image_embeds)
+ time_image_embeds = self.image_norm(time_image_embeds)
+ hint = self.input_hint_block(hint)
+ return time_image_embeds, hint
+
+
+class AttentionPooling(nn.Module):
+ # Copied from https://github.com/deep-floyd/IF/blob/2f91391f27dd3c468bf174be5805b4cc92980c0b/deepfloyd_if/model/nn.py#L54
+
+ def __init__(self, num_heads, embed_dim, dtype=None):
+ super().__init__()
+ self.dtype = dtype
+ self.positional_embedding = nn.Parameter(torch.randn(1, embed_dim) / embed_dim**0.5)
+ self.k_proj = nn.Linear(embed_dim, embed_dim, dtype=self.dtype)
+ self.q_proj = nn.Linear(embed_dim, embed_dim, dtype=self.dtype)
+ self.v_proj = nn.Linear(embed_dim, embed_dim, dtype=self.dtype)
+ self.num_heads = num_heads
+ self.dim_per_head = embed_dim // self.num_heads
+
+ def forward(self, x):
+ bs, length, width = x.size()
+
+ def shape(x):
+ # (bs, length, width) --> (bs, length, n_heads, dim_per_head)
+ x = x.view(bs, -1, self.num_heads, self.dim_per_head)
+ # (bs, length, n_heads, dim_per_head) --> (bs, n_heads, length, dim_per_head)
+ x = x.transpose(1, 2)
+ # (bs, n_heads, length, dim_per_head) --> (bs*n_heads, length, dim_per_head)
+ x = x.reshape(bs * self.num_heads, -1, self.dim_per_head)
+ # (bs*n_heads, length, dim_per_head) --> (bs*n_heads, dim_per_head, length)
+ x = x.transpose(1, 2)
+ return x
+
+ class_token = x.mean(dim=1, keepdim=True) + self.positional_embedding.to(x.dtype)
+ x = torch.cat([class_token, x], dim=1) # (bs, length+1, width)
+
+        # (bs*n_heads, dim_per_head, class_token_length)
+ q = shape(self.q_proj(class_token))
+        # (bs*n_heads, dim_per_head, length+class_token_length)
+ k = shape(self.k_proj(x))
+ v = shape(self.v_proj(x))
+
+ # (bs*n_heads, class_token_length, length+class_token_length):
+ scale = 1 / math.sqrt(math.sqrt(self.dim_per_head))
+ weight = torch.einsum("bct,bcs->bts", q * scale, k * scale) # More stable with f16 than dividing afterwards
+ weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
+
+ # (bs*n_heads, dim_per_head, class_token_length)
+ a = torch.einsum("bts,bcs->bct", weight, v)
+
+        # (bs, 1, width)
+ a = a.reshape(bs, -1, 1).transpose(1, 2)
+
+ return a[:, 0, :] # cls_token
+
+
+def get_fourier_embeds_from_boundingbox(embed_dim, box):
+ """
+ Args:
+ embed_dim: int
+ box: a 3-D tensor [B x N x 4] representing the bounding boxes for GLIGEN pipeline
+ Returns:
+ [B x N x embed_dim] tensor of positional embeddings
+ """
+
+ batch_size, num_boxes = box.shape[:2]
+
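+    # Geometrically spaced frequencies 100**(i/embed_dim); each of the 4 box coordinates is modulated by every frequency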
+ emb = 100 ** (torch.arange(embed_dim) / embed_dim)
+ emb = emb[None, None, None].to(device=box.device, dtype=box.dtype)
+ emb = emb * box.unsqueeze(-1)
+
+ emb = torch.stack((emb.sin(), emb.cos()), dim=-1)
+ emb = emb.permute(0, 1, 3, 4, 2).reshape(batch_size, num_boxes, embed_dim * 2 * 4)
+
+ return emb
+
+
+class GLIGENTextBoundingboxProjection(nn.Module):
+ def __init__(self, positive_len, out_dim, feature_type="text-only", fourier_freqs=8):
+ super().__init__()
+ self.positive_len = positive_len
+ self.out_dim = out_dim
+
+ self.fourier_embedder_dim = fourier_freqs
+ self.position_dim = fourier_freqs * 2 * 4 # 2: sin/cos, 4: xyxy
+
+ if isinstance(out_dim, tuple):
+ out_dim = out_dim[0]
+
+ if feature_type == "text-only":
+ self.linears = nn.Sequential(
+ nn.Linear(self.positive_len + self.position_dim, 512),
+ nn.SiLU(),
+ nn.Linear(512, 512),
+ nn.SiLU(),
+ nn.Linear(512, out_dim),
+ )
+ self.null_positive_feature = torch.nn.Parameter(torch.zeros([self.positive_len]))
+
+ elif feature_type == "text-image":
+ self.linears_text = nn.Sequential(
+ nn.Linear(self.positive_len + self.position_dim, 512),
+ nn.SiLU(),
+ nn.Linear(512, 512),
+ nn.SiLU(),
+ nn.Linear(512, out_dim),
+ )
+ self.linears_image = nn.Sequential(
+ nn.Linear(self.positive_len + self.position_dim, 512),
+ nn.SiLU(),
+ nn.Linear(512, 512),
+ nn.SiLU(),
+ nn.Linear(512, out_dim),
+ )
+ self.null_text_feature = torch.nn.Parameter(torch.zeros([self.positive_len]))
+ self.null_image_feature = torch.nn.Parameter(torch.zeros([self.positive_len]))
+
+ self.null_position_feature = torch.nn.Parameter(torch.zeros([self.position_dim]))
+
+ def forward(
+ self,
+ boxes,
+ masks,
+ positive_embeddings=None,
+ phrases_masks=None,
+ image_masks=None,
+ phrases_embeddings=None,
+ image_embeddings=None,
+ ):
+ masks = masks.unsqueeze(-1)
+
+        # embed the box positions (may include padding as placeholders)
+ xyxy_embedding = get_fourier_embeds_from_boundingbox(self.fourier_embedder_dim, boxes) # B*N*4 -> B*N*C
+
+ # learnable null embedding
+ xyxy_null = self.null_position_feature.view(1, 1, -1)
+
+ # replace padding with learnable null embedding
+ xyxy_embedding = xyxy_embedding * masks + (1 - masks) * xyxy_null
+
+ # positionet with text only information
+ if positive_embeddings is not None:
+ # learnable null embedding
+ positive_null = self.null_positive_feature.view(1, 1, -1)
+
+ # replace padding with learnable null embedding
+ positive_embeddings = positive_embeddings * masks + (1 - masks) * positive_null
+
+ objs = self.linears(torch.cat([positive_embeddings, xyxy_embedding], dim=-1))
+
+ # positionet with text and image information
+ else:
+ phrases_masks = phrases_masks.unsqueeze(-1)
+ image_masks = image_masks.unsqueeze(-1)
+
+ # learnable null embedding
+ text_null = self.null_text_feature.view(1, 1, -1)
+ image_null = self.null_image_feature.view(1, 1, -1)
+
+ # replace padding with learnable null embedding
+ phrases_embeddings = phrases_embeddings * phrases_masks + (1 - phrases_masks) * text_null
+ image_embeddings = image_embeddings * image_masks + (1 - image_masks) * image_null
+
+ objs_text = self.linears_text(torch.cat([phrases_embeddings, xyxy_embedding], dim=-1))
+ objs_image = self.linears_image(torch.cat([image_embeddings, xyxy_embedding], dim=-1))
+ objs = torch.cat([objs_text, objs_image], dim=1)
+
+ return objs
+
+
+class PixArtAlphaCombinedTimestepSizeEmbeddings(nn.Module):
+ """
+ For PixArt-Alpha.
+
+ Reference:
+ https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L164C9-L168C29
+ """
+
+ def __init__(self, embedding_dim, size_emb_dim, use_additional_conditions: bool = False):
+ super().__init__()
+
+ self.outdim = size_emb_dim
+ self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
+ self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
+
+ self.use_additional_conditions = use_additional_conditions
+ if use_additional_conditions:
+ self.additional_condition_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
+ self.resolution_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=size_emb_dim)
+ self.aspect_ratio_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=size_emb_dim)
+
+ def forward(self, timestep, resolution, aspect_ratio, batch_size, hidden_dtype):
+ timesteps_proj = self.time_proj(timestep)
+ timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=hidden_dtype)) # (N, D)
+
+ if self.use_additional_conditions:
+ resolution_emb = self.additional_condition_proj(resolution.flatten()).to(hidden_dtype)
+ resolution_emb = self.resolution_embedder(resolution_emb).reshape(batch_size, -1)
+ aspect_ratio_emb = self.additional_condition_proj(aspect_ratio.flatten()).to(hidden_dtype)
+ aspect_ratio_emb = self.aspect_ratio_embedder(aspect_ratio_emb).reshape(batch_size, -1)
+ conditioning = timesteps_emb + torch.cat([resolution_emb, aspect_ratio_emb], dim=1)
+ else:
+ conditioning = timesteps_emb
+
+ return conditioning
+
+
+class PixArtAlphaTextProjection(nn.Module):
+ """
+ Projects caption embeddings. Also handles dropout for classifier-free guidance.
+
+ Adapted from https://github.com/PixArt-alpha/PixArt-alpha/blob/master/diffusion/model/nets/PixArt_blocks.py
+ """
+
+ def __init__(self, in_features, hidden_size, out_features=None, act_fn="gelu_tanh"):
+ super().__init__()
+ if out_features is None:
+ out_features = hidden_size
+ self.linear_1 = nn.Linear(in_features=in_features, out_features=hidden_size, bias=True)
+ if act_fn == "gelu_tanh":
+ self.act_1 = nn.GELU(approximate="tanh")
+ elif act_fn == "silu":
+ self.act_1 = nn.SiLU()
+ elif act_fn == "silu_fp32":
+ self.act_1 = FP32SiLU()
+ else:
+ raise ValueError(f"Unknown activation function: {act_fn}")
+ self.linear_2 = nn.Linear(in_features=hidden_size, out_features=out_features, bias=True)
+
+ def forward(self, caption):
+ hidden_states = self.linear_1(caption)
+ hidden_states = self.act_1(hidden_states)
+ hidden_states = self.linear_2(hidden_states)
+ return hidden_states
+
+
+class IPAdapterPlusImageProjectionBlock(nn.Module):
+ def __init__(
+ self,
+ embed_dims: int = 768,
+ dim_head: int = 64,
+ heads: int = 16,
+ ffn_ratio: float = 4,
+ ) -> None:
+ super().__init__()
+ from .attention import FeedForward
+
+ self.ln0 = nn.LayerNorm(embed_dims)
+ self.ln1 = nn.LayerNorm(embed_dims)
+ self.attn = Attention(
+ query_dim=embed_dims,
+ dim_head=dim_head,
+ heads=heads,
+ out_bias=False,
+ )
+ self.ff = nn.Sequential(
+ nn.LayerNorm(embed_dims),
+ FeedForward(embed_dims, embed_dims, activation_fn="gelu", mult=ffn_ratio, bias=False),
+ )
+
+ def forward(self, x, latents, residual):
+ encoder_hidden_states = self.ln0(x)
+ latents = self.ln1(latents)
+ encoder_hidden_states = torch.cat([encoder_hidden_states, latents], dim=-2)
+ latents = self.attn(latents, encoder_hidden_states) + residual
+ latents = self.ff(latents) + latents
+ return latents
+
+
+class IPAdapterPlusImageProjection(nn.Module):
+ """Resampler of IP-Adapter Plus.
+
+ Args:
+        embed_dims (int): The feature dimension. Defaults to 768.
+        output_dims (int): The number of output channels, i.e. the same number of channels as
+            `unet.config.cross_attention_dim`. Defaults to 1024.
+        hidden_dims (int): The number of hidden channels. Defaults to 1280.
+        depth (int): The number of blocks. Defaults to 4.
+        dim_head (int): The number of head channels. Defaults to 64.
+        heads (int): Parallel attention heads. Defaults to 16.
+        num_queries (int): The number of queries. Defaults to 8.
+        ffn_ratio (float): The expansion ratio of the feedforward network hidden layer channels. Defaults to 4.
+ """
+
+ def __init__(
+ self,
+ embed_dims: int = 768,
+ output_dims: int = 1024,
+ hidden_dims: int = 1280,
+ depth: int = 4,
+ dim_head: int = 64,
+ heads: int = 16,
+ num_queries: int = 8,
+ ffn_ratio: float = 4,
+ ) -> None:
+ super().__init__()
+ self.latents = nn.Parameter(torch.randn(1, num_queries, hidden_dims) / hidden_dims**0.5)
+
+ self.proj_in = nn.Linear(embed_dims, hidden_dims)
+
+ self.proj_out = nn.Linear(hidden_dims, output_dims)
+ self.norm_out = nn.LayerNorm(output_dims)
+
+ self.layers = nn.ModuleList(
+ [IPAdapterPlusImageProjectionBlock(hidden_dims, dim_head, heads, ffn_ratio) for _ in range(depth)]
+ )
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ """Forward pass.
+
+ Args:
+ x (torch.Tensor): Input Tensor.
+ Returns:
+ torch.Tensor: Output Tensor.
+ """
+ latents = self.latents.repeat(x.size(0), 1, 1)
+
+ x = self.proj_in(x)
+
+ for block in self.layers:
+ residual = latents
+ latents = block(x, latents, residual)
+
+ latents = self.proj_out(latents)
+ return self.norm_out(latents)
+
+
+class IPAdapterFaceIDPlusImageProjection(nn.Module):
+ """FacePerceiverResampler of IP-Adapter Plus.
+
+ Args:
+        embed_dims (int): The feature dimension. Defaults to 768.
+        output_dims (int): The number of output channels, i.e. the same number of channels as
+            `unet.config.cross_attention_dim`. Defaults to 768.
+        hidden_dims (int): The number of hidden channels. Defaults to 1280.
+        id_embeddings_dim (int): The dimension of the input face ID embeddings. Defaults to 512.
+        depth (int): The number of blocks. Defaults to 4.
+        dim_head (int): The number of head channels. Defaults to 64.
+        heads (int): Parallel attention heads. Defaults to 16.
+        num_tokens (int): The number of ID tokens. Defaults to 4.
+        num_queries (int): The number of queries. Defaults to 8.
+        ffn_ratio (float): The expansion ratio of the feedforward network hidden layer channels. Defaults to 4.
+        ffproj_ratio (int): The expansion ratio of the feedforward projection for the ID embeddings. Defaults to 2.
+ """
+
+ def __init__(
+ self,
+ embed_dims: int = 768,
+ output_dims: int = 768,
+ hidden_dims: int = 1280,
+ id_embeddings_dim: int = 512,
+ depth: int = 4,
+ dim_head: int = 64,
+ heads: int = 16,
+ num_tokens: int = 4,
+ num_queries: int = 8,
+ ffn_ratio: float = 4,
+ ffproj_ratio: int = 2,
+ ) -> None:
+ super().__init__()
+ from .attention import FeedForward
+
+ self.num_tokens = num_tokens
+ self.embed_dim = embed_dims
+ self.clip_embeds = None
+ self.shortcut = False
+ self.shortcut_scale = 1.0
+
+ self.proj = FeedForward(id_embeddings_dim, embed_dims * num_tokens, activation_fn="gelu", mult=ffproj_ratio)
+ self.norm = nn.LayerNorm(embed_dims)
+
+ self.proj_in = nn.Linear(hidden_dims, embed_dims)
+
+ self.proj_out = nn.Linear(embed_dims, output_dims)
+ self.norm_out = nn.LayerNorm(output_dims)
+
+ self.layers = nn.ModuleList(
+ [IPAdapterPlusImageProjectionBlock(embed_dims, dim_head, heads, ffn_ratio) for _ in range(depth)]
+ )
+
+ def forward(self, id_embeds: torch.Tensor) -> torch.Tensor:
+ """Forward pass.
+
+ Args:
+ id_embeds (torch.Tensor): Input Tensor (ID embeds).
+ Returns:
+ torch.Tensor: Output Tensor.
+ """
+ id_embeds = id_embeds.to(self.clip_embeds.dtype)
+ id_embeds = self.proj(id_embeds)
+ id_embeds = id_embeds.reshape(-1, self.num_tokens, self.embed_dim)
+ id_embeds = self.norm(id_embeds)
+ latents = id_embeds
+
+ clip_embeds = self.proj_in(self.clip_embeds)
+ x = clip_embeds.reshape(-1, clip_embeds.shape[2], clip_embeds.shape[3])
+
+ for block in self.layers:
+ residual = latents
+ latents = block(x, latents, residual)
+
+ latents = self.proj_out(latents)
+ out = self.norm_out(latents)
+ if self.shortcut:
+ out = id_embeds + self.shortcut_scale * out
+ return out
+
+
+class MultiIPAdapterImageProjection(nn.Module):
+ def __init__(self, IPAdapterImageProjectionLayers: Union[List[nn.Module], Tuple[nn.Module]]):
+ super().__init__()
+ self.image_projection_layers = nn.ModuleList(IPAdapterImageProjectionLayers)
+
+ def forward(self, image_embeds: List[torch.Tensor]):
+ projected_image_embeds = []
+
+ # currently, we accept `image_embeds` as
+ # 1. a tensor (deprecated) with shape [batch_size, embed_dim] or [batch_size, sequence_length, embed_dim]
+        # 2. list of `n` tensors where `n` is the number of ip-adapters, each tensor can have shape [batch_size, num_images, embed_dim] or [batch_size, num_images, sequence_length, embed_dim]
+ if not isinstance(image_embeds, list):
+ deprecation_message = (
+                "You have passed a tensor as `image_embeds`. This is deprecated and will be removed in a future release."
+ " Please make sure to update your script to pass `image_embeds` as a list of tensors to suppress this warning."
+ )
+ deprecate("image_embeds not a list", "1.0.0", deprecation_message, standard_warn=False)
+ image_embeds = [image_embeds.unsqueeze(1)]
+
+ if len(image_embeds) != len(self.image_projection_layers):
+ raise ValueError(
+ f"image_embeds must have the same length as image_projection_layers, got {len(image_embeds)} and {len(self.image_projection_layers)}"
+ )
+
+ for image_embed, image_projection_layer in zip(image_embeds, self.image_projection_layers):
+ batch_size, num_images = image_embed.shape[0], image_embed.shape[1]
+ image_embed = image_embed.reshape((batch_size * num_images,) + image_embed.shape[2:])
+ image_embed = image_projection_layer(image_embed)
+ image_embed = image_embed.reshape((batch_size, num_images) + image_embed.shape[1:])
+
+ projected_image_embeds.append(image_embed)
+
+ return projected_image_embeds
diff --git a/MindIE/MindIE-Torch/built-in/foundation/CogVideoX-5b/cogvideox_5b/models/normalization.py b/MindIE/MindIE-Torch/built-in/foundation/CogVideoX-5b/cogvideox_5b/models/normalization.py
new file mode 100644
index 0000000000000000000000000000000000000000..0d57bf23abb52c3c75f688860f32e5686366831b
--- /dev/null
+++ b/MindIE/MindIE-Torch/built-in/foundation/CogVideoX-5b/cogvideox_5b/models/normalization.py
@@ -0,0 +1,530 @@
+# coding=utf-8
+# Copyright 2024 HuggingFace Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import numbers
+from typing import Dict, Optional, Tuple
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from diffusers.utils import is_torch_version
+from .activations import get_activation
+from .embeddings import (
+ CombinedTimestepLabelEmbeddings,
+ PixArtAlphaCombinedTimestepSizeEmbeddings,
+)
+
+
+class AdaLayerNorm(nn.Module):
+ r"""
+ Norm layer modified to incorporate timestep embeddings.
+
+ Parameters:
+ embedding_dim (`int`): The size of each embedding vector.
+ num_embeddings (`int`, *optional*): The size of the embeddings dictionary.
+        output_dim (`int`, *optional*): The size of the output; defaults to `2 * embedding_dim`.
+        norm_elementwise_affine (`bool`, defaults to `False`): Whether the LayerNorm uses learnable affine parameters.
+        norm_eps (`float`, defaults to `1e-5`): The epsilon used by the LayerNorm.
+        chunk_dim (`int`, defaults to `0`): The dimension along which the scale/shift embedding is chunked.
+ """
+
+ def __init__(
+ self,
+ embedding_dim: int,
+ num_embeddings: Optional[int] = None,
+ output_dim: Optional[int] = None,
+ norm_elementwise_affine: bool = False,
+ norm_eps: float = 1e-5,
+ chunk_dim: int = 0,
+ ):
+ super().__init__()
+
+ self.chunk_dim = chunk_dim
+ output_dim = output_dim or embedding_dim * 2
+
+ if num_embeddings is not None:
+ self.emb = nn.Embedding(num_embeddings, embedding_dim)
+ else:
+ self.emb = None
+
+ self.silu = nn.SiLU()
+ self.linear = nn.Linear(embedding_dim, output_dim)
+ self.norm = nn.LayerNorm(output_dim // 2, norm_eps, norm_elementwise_affine)
+
+ def forward(
+ self, x: torch.Tensor, timestep: Optional[torch.Tensor] = None, temb: Optional[torch.Tensor] = None
+ ) -> torch.Tensor:
+ if self.emb is not None:
+ temb = self.emb(timestep)
+
+ temb = self.linear(self.silu(temb))
+
+ if self.chunk_dim == 1:
+ # This is a bit weird why we have the order of "shift, scale" here and "scale, shift" in the
+ # other if-branch. This branch is specific to CogVideoX for now.
+ shift, scale = temb.chunk(2, dim=1)
+ shift = shift[:, None, :]
+ scale = scale[:, None, :]
+ else:
+ scale, shift = temb.chunk(2, dim=0)
+
+ x = self.norm(x) * (1 + scale) + shift
+ return x
+
+
+class FP32LayerNorm(nn.LayerNorm):
+ def forward(self, inputs: torch.Tensor) -> torch.Tensor:
+ origin_dtype = inputs.dtype
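+        # Run LayerNorm in float32 for numerical stability, then cast the result back to the input dtype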
+ return F.layer_norm(
+ inputs.float(),
+ self.normalized_shape,
+ self.weight.float() if self.weight is not None else None,
+ self.bias.float() if self.bias is not None else None,
+ self.eps,
+ ).to(origin_dtype)
+
+
+class SD35AdaLayerNormZeroX(nn.Module):
+ r"""
+ Norm layer adaptive layer norm zero (AdaLN-Zero).
+
+ Parameters:
+ embedding_dim (`int`): The size of each embedding vector.
+        norm_type (`str`, defaults to `"layer_norm"`): The type of normalization layer to use.
+ """
+
+ def __init__(self, embedding_dim: int, norm_type: str = "layer_norm", bias: bool = True) -> None:
+ super().__init__()
+
+ self.silu = nn.SiLU()
+ self.linear = nn.Linear(embedding_dim, 9 * embedding_dim, bias=bias)
+ if norm_type == "layer_norm":
+ self.norm = nn.LayerNorm(embedding_dim, elementwise_affine=False, eps=1e-6)
+ else:
+ raise ValueError(f"Unsupported `norm_type` ({norm_type}) provided. Supported ones are: 'layer_norm'.")
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ emb: Optional[torch.Tensor] = None,
+ ) -> Tuple[torch.Tensor, ...]:
+ emb = self.linear(self.silu(emb))
+ shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp, shift_msa2, scale_msa2, gate_msa2 = emb.chunk(
+ 9, dim=1
+ )
+ norm_hidden_states = self.norm(hidden_states)
+ hidden_states = norm_hidden_states * (1 + scale_msa[:, None]) + shift_msa[:, None]
+ norm_hidden_states2 = norm_hidden_states * (1 + scale_msa2[:, None]) + shift_msa2[:, None]
+ return hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp, norm_hidden_states2, gate_msa2
+
+
+class AdaLayerNormZero(nn.Module):
+ r"""
+ Norm layer adaptive layer norm zero (adaLN-Zero).
+
+ Parameters:
+ embedding_dim (`int`): The size of each embedding vector.
+ num_embeddings (`int`): The size of the embeddings dictionary.
+ """
+
+ def __init__(self, embedding_dim: int, num_embeddings: Optional[int] = None, norm_type="layer_norm", bias=True):
+ super().__init__()
+ if num_embeddings is not None:
+ self.emb = CombinedTimestepLabelEmbeddings(num_embeddings, embedding_dim)
+ else:
+ self.emb = None
+
+ self.silu = nn.SiLU()
+ self.linear = nn.Linear(embedding_dim, 6 * embedding_dim, bias=bias)
+ if norm_type == "layer_norm":
+ self.norm = nn.LayerNorm(embedding_dim, elementwise_affine=False, eps=1e-6)
+ elif norm_type == "fp32_layer_norm":
+ self.norm = FP32LayerNorm(embedding_dim, elementwise_affine=False, bias=False)
+ else:
+ raise ValueError(
+ f"Unsupported `norm_type` ({norm_type}) provided. Supported ones are: 'layer_norm', 'fp32_layer_norm'."
+ )
+
+ def forward(
+ self,
+ x: torch.Tensor,
+ timestep: Optional[torch.Tensor] = None,
+ class_labels: Optional[torch.LongTensor] = None,
+ hidden_dtype: Optional[torch.dtype] = None,
+ emb: Optional[torch.Tensor] = None,
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+ if self.emb is not None:
+ emb = self.emb(timestep, class_labels, hidden_dtype=hidden_dtype)
+ emb = self.linear(self.silu(emb))
+ shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = emb.chunk(6, dim=1)
+ x = self.norm(x) * (1 + scale_msa[:, None]) + shift_msa[:, None]
+ return x, gate_msa, shift_mlp, scale_mlp, gate_mlp
+
+
+class AdaLayerNormZeroSingle(nn.Module):
+ r"""
+ Norm layer adaptive layer norm zero (adaLN-Zero).
+
+ Parameters:
+ embedding_dim (`int`): The size of each embedding vector.
+        norm_type (`str`, defaults to `"layer_norm"`): The type of normalization layer to use.
+ """
+
+ def __init__(self, embedding_dim: int, norm_type="layer_norm", bias=True):
+ super().__init__()
+
+ self.silu = nn.SiLU()
+ self.linear = nn.Linear(embedding_dim, 3 * embedding_dim, bias=bias)
+ if norm_type == "layer_norm":
+ self.norm = nn.LayerNorm(embedding_dim, elementwise_affine=False, eps=1e-6)
+ else:
+ raise ValueError(
+                f"Unsupported `norm_type` ({norm_type}) provided. Supported ones are: 'layer_norm'."
+ )
+
+ def forward(
+ self,
+ x: torch.Tensor,
+ emb: Optional[torch.Tensor] = None,
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+ emb = self.linear(self.silu(emb))
+ shift_msa, scale_msa, gate_msa = emb.chunk(3, dim=1)
+ x = self.norm(x) * (1 + scale_msa[:, None]) + shift_msa[:, None]
+ return x, gate_msa
+
+
+class LuminaRMSNormZero(nn.Module):
+ """
+ Norm layer adaptive RMS normalization zero.
+
+ Parameters:
+ embedding_dim (`int`): The size of each embedding vector.
+ """
+
+ def __init__(self, embedding_dim: int, norm_eps: float, norm_elementwise_affine: bool):
+ super().__init__()
+ self.silu = nn.SiLU()
+ self.linear = nn.Linear(
+ min(embedding_dim, 1024),
+ 4 * embedding_dim,
+ bias=True,
+ )
+ self.norm = RMSNorm(embedding_dim, eps=norm_eps, elementwise_affine=norm_elementwise_affine)
+
+ def forward(
+ self,
+ x: torch.Tensor,
+ emb: Optional[torch.Tensor] = None,
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+ # emb = self.emb(timestep, encoder_hidden_states, encoder_mask)
+ emb = self.linear(self.silu(emb))
+ scale_msa, gate_msa, scale_mlp, gate_mlp = emb.chunk(4, dim=1)
+ x = self.norm(x) * (1 + scale_msa[:, None])
+
+ return x, gate_msa, scale_mlp, gate_mlp
+
+
+class AdaLayerNormSingle(nn.Module):
+ r"""
+ Norm layer adaptive layer norm single (adaLN-single).
+
+ As proposed in PixArt-Alpha (see: https://arxiv.org/abs/2310.00426; Section 2.3).
+
+ Parameters:
+ embedding_dim (`int`): The size of each embedding vector.
+ use_additional_conditions (`bool`): To use additional conditions for normalization or not.
+ """
+
+ def __init__(self, embedding_dim: int, use_additional_conditions: bool = False):
+ super().__init__()
+
+ self.emb = PixArtAlphaCombinedTimestepSizeEmbeddings(
+ embedding_dim, size_emb_dim=embedding_dim // 3, use_additional_conditions=use_additional_conditions
+ )
+
+ self.silu = nn.SiLU()
+ self.linear = nn.Linear(embedding_dim, 6 * embedding_dim, bias=True)
+
+ def forward(
+ self,
+ timestep: torch.Tensor,
+ added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
+ batch_size: Optional[int] = None,
+ hidden_dtype: Optional[torch.dtype] = None,
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+ # No modulation happening here.
+ embedded_timestep = self.emb(timestep, **added_cond_kwargs, batch_size=batch_size, hidden_dtype=hidden_dtype)
+ return self.linear(self.silu(embedded_timestep)), embedded_timestep
+
+
+class AdaGroupNorm(nn.Module):
+ r"""
+ GroupNorm layer modified to incorporate timestep embeddings.
+
+ Parameters:
+ embedding_dim (`int`): The size of each embedding vector.
+ num_embeddings (`int`): The size of the embeddings dictionary.
+ num_groups (`int`): The number of groups to separate the channels into.
+ act_fn (`str`, *optional*, defaults to `None`): The activation function to use.
+ eps (`float`, *optional*, defaults to `1e-5`): The epsilon value to use for numerical stability.
+ """
+
+ def __init__(
+ self, embedding_dim: int, out_dim: int, num_groups: int, act_fn: Optional[str] = None, eps: float = 1e-5
+ ):
+ super().__init__()
+ self.num_groups = num_groups
+ self.eps = eps
+
+ if act_fn is None:
+ self.act = None
+ else:
+ self.act = get_activation(act_fn)
+
+ self.linear = nn.Linear(embedding_dim, out_dim * 2)
+
+ def forward(self, x: torch.Tensor, emb: torch.Tensor) -> torch.Tensor:
+ if self.act:
+ emb = self.act(emb)
+ emb = self.linear(emb)
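+        # broadcast the projected conditioning over the spatial dims ([B, 2 * out_dim] ->
+        # [B, 2 * out_dim, 1, 1]) and split it into per-channel scale and shift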
+ emb = emb[:, :, None, None]
+ scale, shift = emb.chunk(2, dim=1)
+
+ x = F.group_norm(x, self.num_groups, eps=self.eps)
+ x = x * (1 + scale) + shift
+ return x
+
+
+class AdaLayerNormContinuous(nn.Module):
+ def __init__(
+ self,
+ embedding_dim: int,
+ conditioning_embedding_dim: int,
+ # NOTE: It is a bit weird that the norm layer can be configured to have scale and shift parameters
+ # because the output is immediately scaled and shifted by the projected conditioning embeddings.
+ # Note that AdaLayerNorm does not let the norm layer have scale and shift parameters.
+ # However, this is how it was implemented in the original code, and it's rather likely you should
+ # set `elementwise_affine` to False.
+ elementwise_affine=True,
+ eps=1e-5,
+ bias=True,
+ norm_type="layer_norm",
+ ):
+ super().__init__()
+ self.silu = nn.SiLU()
+ self.linear = nn.Linear(conditioning_embedding_dim, embedding_dim * 2, bias=bias)
+ if norm_type == "layer_norm":
+ self.norm = LayerNorm(embedding_dim, eps, elementwise_affine, bias)
+ elif norm_type == "rms_norm":
+ self.norm = RMSNorm(embedding_dim, eps, elementwise_affine)
+ else:
+ raise ValueError(f"unknown norm_type {norm_type}")
+
+ def forward(self, x: torch.Tensor, conditioning_embedding: torch.Tensor) -> torch.Tensor:
+        # convert back to the original dtype in case `conditioning_embedding` is upcast to float32 (needed for hunyuanDiT)
+ emb = self.linear(self.silu(conditioning_embedding).to(x.dtype))
+ scale, shift = torch.chunk(emb, 2, dim=1)
+ x = self.norm(x) * (1 + scale)[:, None, :] + shift[:, None, :]
+ return x
+
+
+class LuminaLayerNormContinuous(nn.Module):
+ def __init__(
+ self,
+ embedding_dim: int,
+ conditioning_embedding_dim: int,
+ # NOTE: It is a bit weird that the norm layer can be configured to have scale and shift parameters
+ # because the output is immediately scaled and shifted by the projected conditioning embeddings.
+ # Note that AdaLayerNorm does not let the norm layer have scale and shift parameters.
+ # However, this is how it was implemented in the original code, and it's rather likely you should
+ # set `elementwise_affine` to False.
+ elementwise_affine=True,
+ eps=1e-5,
+ bias=True,
+ norm_type="layer_norm",
+ out_dim: Optional[int] = None,
+ ):
+ super().__init__()
+ # AdaLN
+ self.silu = nn.SiLU()
+ self.linear_1 = nn.Linear(conditioning_embedding_dim, embedding_dim, bias=bias)
+ if norm_type == "layer_norm":
+ self.norm = LayerNorm(embedding_dim, eps, elementwise_affine, bias)
+ else:
+ raise ValueError(f"unknown norm_type {norm_type}")
+ # linear_2
+        if out_dim is not None:
+            self.linear_2 = nn.Linear(
+                embedding_dim,
+                out_dim,
+                bias=bias,
+            )
+        else:
+            # avoid an AttributeError in `forward`, which checks `self.linear_2 is not None`
+            self.linear_2 = None
+
+ def forward(
+ self,
+ x: torch.Tensor,
+ conditioning_embedding: torch.Tensor,
+ ) -> torch.Tensor:
+        # convert back to the original dtype in case `conditioning_embedding` is upcast to float32 (needed for hunyuanDiT)
+ emb = self.linear_1(self.silu(conditioning_embedding).to(x.dtype))
+ scale = emb
+ x = self.norm(x) * (1 + scale)[:, None, :]
+
+ if self.linear_2 is not None:
+ x = self.linear_2(x)
+
+ return x
+
+
+class CogView3PlusAdaLayerNormZeroTextImage(nn.Module):
+ r"""
+ Norm layer adaptive layer norm zero (adaLN-Zero).
+
+ Parameters:
+ embedding_dim (`int`): The size of each embedding vector.
+        dim (`int`): The hidden dimension of the modulated hidden states and context.
+ """
+
+ def __init__(self, embedding_dim: int, dim: int):
+ super().__init__()
+
+ self.silu = nn.SiLU()
+ self.linear = nn.Linear(embedding_dim, 12 * dim, bias=True)
+ self.norm_x = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-5)
+ self.norm_c = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-5)
+
+ def forward(
+ self,
+ x: torch.Tensor,
+ context: torch.Tensor,
+ emb: Optional[torch.Tensor] = None,
+    ) -> Tuple[torch.Tensor, ...]:
+ emb = self.linear(self.silu(emb))
+ (
+ shift_msa,
+ scale_msa,
+ gate_msa,
+ shift_mlp,
+ scale_mlp,
+ gate_mlp,
+ c_shift_msa,
+ c_scale_msa,
+ c_gate_msa,
+ c_shift_mlp,
+ c_scale_mlp,
+ c_gate_mlp,
+ ) = emb.chunk(12, dim=1)
+ normed_x = self.norm_x(x)
+ normed_context = self.norm_c(context)
+ x = normed_x * (1 + scale_msa[:, None]) + shift_msa[:, None]
+ context = normed_context * (1 + c_scale_msa[:, None]) + c_shift_msa[:, None]
+ return x, gate_msa, shift_mlp, scale_mlp, gate_mlp, context, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp
+
+
+class CogVideoXLayerNormZero(nn.Module):
+ def __init__(
+ self,
+ conditioning_dim: int,
+ embedding_dim: int,
+ elementwise_affine: bool = True,
+ eps: float = 1e-5,
+ bias: bool = True,
+ ) -> None:
+ super().__init__()
+
+ self.silu = nn.SiLU()
+ self.linear = nn.Linear(conditioning_dim, 6 * embedding_dim, bias=bias)
+ self.norm = nn.LayerNorm(embedding_dim, eps=eps, elementwise_affine=elementwise_affine)
+
+ def forward(
+ self, hidden_states: torch.Tensor, encoder_hidden_states: torch.Tensor, temb: torch.Tensor
+    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
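+        # A single conditioning vector produces separate shift/scale/gate sets for the video stream
+        # (hidden_states) and the text stream (encoder_hidden_states); the same LayerNorm is applied
+        # to both before modulation, and the gates are returned for residual scaling in the caller.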
+ shift, scale, gate, enc_shift, enc_scale, enc_gate = self.linear(self.silu(temb)).chunk(6, dim=1)
+ hidden_states = self.norm(hidden_states) * (1 + scale)[:, None, :] + shift[:, None, :]
+ encoder_hidden_states = self.norm(encoder_hidden_states) * (1 + enc_scale)[:, None, :] + enc_shift[:, None, :]
+ return hidden_states, encoder_hidden_states, gate[:, None, :], enc_gate[:, None, :]
+
+
+if is_torch_version(">=", "2.1.0"):
+ LayerNorm = nn.LayerNorm
+else:
+ # Has optional bias parameter compared to torch layer norm
+ # TODO: replace with torch layernorm once min required torch version >= 2.1
+ class LayerNorm(nn.Module):
+ def __init__(self, dim, eps: float = 1e-5, elementwise_affine: bool = True, bias: bool = True):
+ super().__init__()
+
+ self.eps = eps
+
+ if isinstance(dim, numbers.Integral):
+ dim = (dim,)
+
+ self.dim = torch.Size(dim)
+
+ if elementwise_affine:
+ self.weight = nn.Parameter(torch.ones(dim))
+ self.bias = nn.Parameter(torch.zeros(dim)) if bias else None
+ else:
+ self.weight = None
+ self.bias = None
+
+ def forward(self, input):
+ return F.layer_norm(input, self.dim, self.weight, self.bias, self.eps)
+
+
+class RMSNorm(nn.Module):
+ def __init__(self, dim, eps: float, elementwise_affine: bool = True):
+ super().__init__()
+
+ self.eps = eps
+
+ if isinstance(dim, numbers.Integral):
+ dim = (dim,)
+
+ self.dim = torch.Size(dim)
+
+ if elementwise_affine:
+ self.weight = nn.Parameter(torch.ones(dim))
+ else:
+ self.weight = None
+
+ def forward(self, hidden_states):
+ input_dtype = hidden_states.dtype
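+        # RMSNorm: x * rsqrt(mean(x^2) + eps); the statistics are computed in float32 for
+        # numerical stability and the result is cast back to a matching dtype below.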
+ variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
+ hidden_states = hidden_states * torch.rsqrt(variance + self.eps)
+
+ if self.weight is not None:
+ # convert into half-precision if necessary
+ if self.weight.dtype in [torch.float16, torch.bfloat16]:
+ hidden_states = hidden_states.to(self.weight.dtype)
+ hidden_states = hidden_states * self.weight
+ else:
+ hidden_states = hidden_states.to(input_dtype)
+
+ return hidden_states
+
+
+class GlobalResponseNorm(nn.Module):
+ # Taken from https://github.com/facebookresearch/ConvNeXt-V2/blob/3608f67cc1dae164790c5d0aead7bf2d73d9719b/models/utils.py#L105
+ def __init__(self, dim):
+ super().__init__()
+ self.gamma = nn.Parameter(torch.zeros(1, 1, 1, dim))
+ self.beta = nn.Parameter(torch.zeros(1, 1, 1, dim))
+
+ def forward(self, x):
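+        # gx: per-channel L2 response aggregated over the spatial dims; nx: that response divided by
+        # its mean across channels. gamma/beta are zero-initialized, so the layer starts out as the
+        # identity thanks to the residual `+ x` term.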
+ gx = torch.norm(x, p=2, dim=(1, 2), keepdim=True)
+ nx = gx / (gx.mean(dim=-1, keepdim=True) + 1e-6)
+ return self.gamma * (x * nx) + self.beta + x
diff --git a/MindIE/MindIE-Torch/built-in/foundation/CogVideoX-5b/cogvideox_5b/models/transformers/__init__.py b/MindIE/MindIE-Torch/built-in/foundation/CogVideoX-5b/cogvideox_5b/models/transformers/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..d5899e9cd745185db78b1d5846f695978b32eee1
--- /dev/null
+++ b/MindIE/MindIE-Torch/built-in/foundation/CogVideoX-5b/cogvideox_5b/models/transformers/__init__.py
@@ -0,0 +1 @@
+from .cogvideox_transformer_3d import CogVideoXTransformer3DModel
diff --git a/MindIE/MindIE-Torch/built-in/foundation/CogVideoX-5b/cogvideox_5b/models/transformers/cogvideox_transformer_3d.py b/MindIE/MindIE-Torch/built-in/foundation/CogVideoX-5b/cogvideox_5b/models/transformers/cogvideox_transformer_3d.py
new file mode 100644
index 0000000000000000000000000000000000000000..0cdc4a0cdd8008862bfd69f37abcc1e0156b0098
--- /dev/null
+++ b/MindIE/MindIE-Torch/built-in/foundation/CogVideoX-5b/cogvideox_5b/models/transformers/cogvideox_transformer_3d.py
@@ -0,0 +1,506 @@
+# Copyright 2024 The CogVideoX team, Tsinghua University & ZhipuAI and The HuggingFace Team.
+# All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Any, Dict, Optional, Tuple, Union
+
+import torch
+from torch import nn
+
+from diffusers.configuration_utils import ConfigMixin, register_to_config
+from diffusers.loaders import PeftAdapterMixin
+from diffusers.utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers
+from diffusers.utils.torch_utils import maybe_allow_in_graph
+from diffusers.models.modeling_outputs import Transformer2DModelOutput
+from diffusers.models.modeling_utils import ModelMixin
+from ..attention import Attention, FeedForward
+from ..attention_processor import AttentionProcessor, CogVideoXAttnProcessor2_0, FusedCogVideoXAttnProcessor2_0
+from ..embeddings import CogVideoXPatchEmbed, TimestepEmbedding, Timesteps
+from ..normalization import AdaLayerNorm, CogVideoXLayerNormZero
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+
+@maybe_allow_in_graph
+class CogVideoXBlock(nn.Module):
+ r"""
+ Transformer block used in [CogVideoX](https://github.com/THUDM/CogVideo) model.
+
+ Parameters:
+ dim (`int`):
+ The number of channels in the input and output.
+ num_attention_heads (`int`):
+ The number of heads to use for multi-head attention.
+ attention_head_dim (`int`):
+ The number of channels in each head.
+ time_embed_dim (`int`):
+ The number of channels in timestep embedding.
+ dropout (`float`, defaults to `0.0`):
+ The dropout probability to use.
+ activation_fn (`str`, defaults to `"gelu-approximate"`):
+ Activation function to be used in feed-forward.
+ attention_bias (`bool`, defaults to `False`):
+ Whether or not to use bias in attention projection layers.
+ qk_norm (`bool`, defaults to `True`):
+ Whether or not to use normalization after query and key projections in Attention.
+ norm_elementwise_affine (`bool`, defaults to `True`):
+ Whether to use learnable elementwise affine parameters for normalization.
+ norm_eps (`float`, defaults to `1e-5`):
+ Epsilon value for normalization layers.
+        final_dropout (`bool`, defaults to `True`):
+ Whether to apply a final dropout after the last feed-forward layer.
+ ff_inner_dim (`int`, *optional*, defaults to `None`):
+ Custom hidden dimension of Feed-forward layer. If not provided, `4 * dim` is used.
+ ff_bias (`bool`, defaults to `True`):
+ Whether or not to use bias in Feed-forward layer.
+ attention_out_bias (`bool`, defaults to `True`):
+ Whether or not to use bias in Attention output projection layer.
+ """
+
+ def __init__(
+ self,
+ dim: int,
+ num_attention_heads: int,
+ attention_head_dim: int,
+ time_embed_dim: int,
+ dropout: float = 0.0,
+ activation_fn: str = "gelu-approximate",
+ attention_bias: bool = False,
+ qk_norm: bool = True,
+ norm_elementwise_affine: bool = True,
+ norm_eps: float = 1e-5,
+ final_dropout: bool = True,
+ ff_inner_dim: Optional[int] = None,
+ ff_bias: bool = True,
+ attention_out_bias: bool = True,
+ ):
+ super().__init__()
+
+ # 1. Self Attention
+ self.norm1 = CogVideoXLayerNormZero(time_embed_dim, dim, norm_elementwise_affine, norm_eps, bias=True)
+
+ self.attn1 = Attention(
+ query_dim=dim,
+ dim_head=attention_head_dim,
+ heads=num_attention_heads,
+ qk_norm="layer_norm" if qk_norm else None,
+ eps=1e-6,
+ bias=attention_bias,
+ out_bias=attention_out_bias,
+ processor=CogVideoXAttnProcessor2_0(),
+ )
+
+ # 2. Feed Forward
+ self.norm2 = CogVideoXLayerNormZero(time_embed_dim, dim, norm_elementwise_affine, norm_eps, bias=True)
+
+ self.ff = FeedForward(
+ dim,
+ dropout=dropout,
+ activation_fn=activation_fn,
+ final_dropout=final_dropout,
+ inner_dim=ff_inner_dim,
+ bias=ff_bias,
+ )
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ encoder_hidden_states: torch.Tensor,
+ temb: torch.Tensor,
+ image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+ text_seq_length = encoder_hidden_states.size(1)
+ # norm & modulate
+ norm_hidden_states, norm_encoder_hidden_states, gate_msa, enc_gate_msa = self.norm1(
+ hidden_states, encoder_hidden_states, temb
+ )
+
+ # attention
+ attn_hidden_states, attn_encoder_hidden_states = self.attn1(
+ hidden_states=norm_hidden_states,
+ encoder_hidden_states=norm_encoder_hidden_states,
+ image_rotary_emb=image_rotary_emb,
+ )
+
+ hidden_states = hidden_states + gate_msa * attn_hidden_states
+ encoder_hidden_states = encoder_hidden_states + enc_gate_msa * attn_encoder_hidden_states
+
+ # norm & modulate
+ norm_hidden_states, norm_encoder_hidden_states, gate_ff, enc_gate_ff = self.norm2(
+ hidden_states, encoder_hidden_states, temb
+ )
+
+ # feed-forward
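+        # text and video tokens are processed jointly: concatenate along the sequence dimension,
+        # run the feed-forward once, then split the result back by text_seq_length so each stream
+        # receives its own gated residual update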
+ norm_hidden_states = torch.cat([norm_encoder_hidden_states, norm_hidden_states], dim=1)
+ ff_output = self.ff(norm_hidden_states)
+
+ hidden_states = hidden_states + gate_ff * ff_output[:, text_seq_length:]
+ encoder_hidden_states = encoder_hidden_states + enc_gate_ff * ff_output[:, :text_seq_length]
+
+ return hidden_states, encoder_hidden_states
+
+
+class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
+ """
+ A Transformer model for video-like data in [CogVideoX](https://github.com/THUDM/CogVideo).
+
+ Parameters:
+ num_attention_heads (`int`, defaults to `30`):
+ The number of heads to use for multi-head attention.
+ attention_head_dim (`int`, defaults to `64`):
+ The number of channels in each head.
+ in_channels (`int`, defaults to `16`):
+ The number of channels in the input.
+ out_channels (`int`, *optional*, defaults to `16`):
+ The number of channels in the output.
+ flip_sin_to_cos (`bool`, defaults to `True`):
+ Whether to flip the sin to cos in the time embedding.
+ time_embed_dim (`int`, defaults to `512`):
+ Output dimension of timestep embeddings.
+ text_embed_dim (`int`, defaults to `4096`):
+ Input dimension of text embeddings from the text encoder.
+ num_layers (`int`, defaults to `30`):
+ The number of layers of Transformer blocks to use.
+ dropout (`float`, defaults to `0.0`):
+ The dropout probability to use.
+ attention_bias (`bool`, defaults to `True`):
+ Whether or not to use bias in the attention projection layers.
+ sample_width (`int`, defaults to `90`):
+ The width of the input latents.
+ sample_height (`int`, defaults to `60`):
+ The height of the input latents.
+ sample_frames (`int`, defaults to `49`):
+ The number of frames in the input latents. Note that this parameter was incorrectly initialized to 49
+ instead of 13 because CogVideoX processed 13 latent frames at once in its default and recommended settings,
+ but cannot be changed to the correct value to ensure backwards compatibility. To create a transformer with
+ K latent frames, the correct value to pass here would be: ((K - 1) * temporal_compression_ratio + 1).
+ patch_size (`int`, defaults to `2`):
+ The size of the patches to use in the patch embedding layer.
+ temporal_compression_ratio (`int`, defaults to `4`):
+ The compression ratio across the temporal dimension. See documentation for `sample_frames`.
+ max_text_seq_length (`int`, defaults to `226`):
+ The maximum sequence length of the input text embeddings.
+ activation_fn (`str`, defaults to `"gelu-approximate"`):
+ Activation function to use in feed-forward.
+ timestep_activation_fn (`str`, defaults to `"silu"`):
+ Activation function to use when generating the timestep embeddings.
+ norm_elementwise_affine (`bool`, defaults to `True`):
+ Whether or not to use elementwise affine in normalization layers.
+ norm_eps (`float`, defaults to `1e-5`):
+ The epsilon value to use in normalization layers.
+ spatial_interpolation_scale (`float`, defaults to `1.875`):
+ Scaling factor to apply in 3D positional embeddings across spatial dimensions.
+ temporal_interpolation_scale (`float`, defaults to `1.0`):
+ Scaling factor to apply in 3D positional embeddings across temporal dimensions.
+ """
+
+ _supports_gradient_checkpointing = True
+
+ @register_to_config
+ def __init__(
+ self,
+ num_attention_heads: int = 30,
+ attention_head_dim: int = 64,
+ in_channels: int = 16,
+ out_channels: Optional[int] = 16,
+ flip_sin_to_cos: bool = True,
+ freq_shift: int = 0,
+ time_embed_dim: int = 512,
+ text_embed_dim: int = 4096,
+ num_layers: int = 30,
+ dropout: float = 0.0,
+ attention_bias: bool = True,
+ sample_width: int = 90,
+ sample_height: int = 60,
+ sample_frames: int = 49,
+ patch_size: int = 2,
+ temporal_compression_ratio: int = 4,
+ max_text_seq_length: int = 226,
+ activation_fn: str = "gelu-approximate",
+ timestep_activation_fn: str = "silu",
+ norm_elementwise_affine: bool = True,
+ norm_eps: float = 1e-5,
+ spatial_interpolation_scale: float = 1.875,
+ temporal_interpolation_scale: float = 1.0,
+ use_rotary_positional_embeddings: bool = False,
+ use_learned_positional_embeddings: bool = False,
+ ):
+ super().__init__()
+ inner_dim = num_attention_heads * attention_head_dim
+
+ if not use_rotary_positional_embeddings and use_learned_positional_embeddings:
+ raise ValueError(
+                "There are no CogVideoX checkpoints available that disable rotary embeddings and use learned "
+                "positional embeddings. If you're using a custom model and/or believe this should be supported, "
+                "please open an issue at https://github.com/huggingface/diffusers/issues."
+ )
+
+ # 1. Patch embedding
+ self.patch_embed = CogVideoXPatchEmbed(
+ patch_size=patch_size,
+ in_channels=in_channels,
+ embed_dim=inner_dim,
+ text_embed_dim=text_embed_dim,
+ bias=True,
+ sample_width=sample_width,
+ sample_height=sample_height,
+ sample_frames=sample_frames,
+ temporal_compression_ratio=temporal_compression_ratio,
+ max_text_seq_length=max_text_seq_length,
+ spatial_interpolation_scale=spatial_interpolation_scale,
+ temporal_interpolation_scale=temporal_interpolation_scale,
+ use_positional_embeddings=not use_rotary_positional_embeddings,
+ use_learned_positional_embeddings=use_learned_positional_embeddings,
+ )
+ self.embedding_dropout = nn.Dropout(dropout)
+
+ # 2. Time embeddings
+ self.time_proj = Timesteps(inner_dim, flip_sin_to_cos, freq_shift)
+ self.time_embedding = TimestepEmbedding(inner_dim, time_embed_dim, timestep_activation_fn)
+
+        # 3. Define spatio-temporal transformer blocks
+ self.transformer_blocks = nn.ModuleList(
+ [
+ CogVideoXBlock(
+ dim=inner_dim,
+ num_attention_heads=num_attention_heads,
+ attention_head_dim=attention_head_dim,
+ time_embed_dim=time_embed_dim,
+ dropout=dropout,
+ activation_fn=activation_fn,
+ attention_bias=attention_bias,
+ norm_elementwise_affine=norm_elementwise_affine,
+ norm_eps=norm_eps,
+ )
+ for _ in range(num_layers)
+ ]
+ )
+ self.norm_final = nn.LayerNorm(inner_dim, norm_eps, norm_elementwise_affine)
+
+ # 4. Output blocks
+ self.norm_out = AdaLayerNorm(
+ embedding_dim=time_embed_dim,
+ output_dim=2 * inner_dim,
+ norm_elementwise_affine=norm_elementwise_affine,
+ norm_eps=norm_eps,
+ chunk_dim=1,
+ )
+ self.proj_out = nn.Linear(inner_dim, patch_size * patch_size * out_channels)
+
+ self.gradient_checkpointing = False
+
+ def _set_gradient_checkpointing(self, module, value=False):
+ self.gradient_checkpointing = value
+
+ @property
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
+ def attn_processors(self) -> Dict[str, AttentionProcessor]:
+ r"""
+ Returns:
+            `dict` of attention processors: A dictionary containing all attention processors used in the model,
+            indexed by their weight names.
+ """
+ # set recursively
+ processors = {}
+
+ def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
+ if hasattr(module, "get_processor"):
+ processors[f"{name}.processor"] = module.get_processor()
+
+ for sub_name, child in module.named_children():
+ fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
+
+ return processors
+
+ for name, module in self.named_children():
+ fn_recursive_add_processors(name, module, processors)
+
+ return processors
+
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
+ def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
+ r"""
+ Sets the attention processor to use to compute attention.
+
+ Parameters:
+ processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
+ The instantiated processor class or a dictionary of processor classes that will be set as the processor
+ for **all** `Attention` layers.
+
+ If `processor` is a dict, the key needs to define the path to the corresponding cross attention
+ processor. This is strongly recommended when setting trainable attention processors.
+
+ """
+ count = len(self.attn_processors.keys())
+
+ if isinstance(processor, dict) and len(processor) != count:
+ raise ValueError(
+ f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
+ f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
+ )
+
+ def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
+ if hasattr(module, "set_processor"):
+ if not isinstance(processor, dict):
+ module.set_processor(processor)
+ else:
+ module.set_processor(processor.pop(f"{name}.processor"))
+
+ for sub_name, child in module.named_children():
+ fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
+
+ for name, module in self.named_children():
+ fn_recursive_attn_processor(name, module, processor)
+
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections with FusedAttnProcessor2_0->FusedCogVideoXAttnProcessor2_0
+ def fuse_qkv_projections(self):
+ """
+ Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
+ are fused. For cross-attention modules, key and value projection matrices are fused.
+
+        This API is 🧪 experimental.
+        """
+ self.original_attn_processors = None
+
+ for _, attn_processor in self.attn_processors.items():
+ if "Added" in str(attn_processor.__class__.__name__):
+ raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.")
+
+ self.original_attn_processors = self.attn_processors
+
+ for module in self.modules():
+ if isinstance(module, Attention):
+ module.fuse_projections(fuse=True)
+
+ self.set_attn_processor(FusedCogVideoXAttnProcessor2_0())
+
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections
+ def unfuse_qkv_projections(self):
+ """Disables the fused QKV projection if enabled.
+
+        This API is 🧪 experimental.
+        """
+ if self.original_attn_processors is not None:
+ self.set_attn_processor(self.original_attn_processors)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ encoder_hidden_states: torch.Tensor,
+ timestep: Union[int, float, torch.LongTensor],
+ timestep_cond: Optional[torch.Tensor] = None,
+ image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+ attention_kwargs: Optional[Dict[str, Any]] = None,
+ return_dict: bool = True,
+ ):
+ if attention_kwargs is not None:
+ attention_kwargs = attention_kwargs.copy()
+ lora_scale = attention_kwargs.pop("scale", 1.0)
+ else:
+ lora_scale = 1.0
+
+ if USE_PEFT_BACKEND:
+ # weight the lora layers by setting `lora_scale` for each PEFT layer
+ scale_lora_layers(self, lora_scale)
+ else:
+ if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None:
+ logger.warning(
+ "Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
+ )
+
+ batch_size, num_frames, channels, height, width = hidden_states.shape
+ # 1. Time embedding
+ timesteps = timestep
+ t_emb = self.time_proj(timesteps)
+
+ # timesteps does not contain any weights and will always return f32 tensors
+ # but time_embedding might actually be running in fp16. so we need to cast here.
+ # there might be better ways to encapsulate this.
+ t_emb = t_emb.to(dtype=hidden_states.dtype)
+ emb = self.time_embedding(t_emb, timestep_cond)
+
+ # 2. Patch embedding
+ hidden_states = self.patch_embed(encoder_hidden_states, hidden_states)
+ hidden_states = self.embedding_dropout(hidden_states)
+
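+        # the patch embedding prepends the (projected) text embeddings to the video patch tokens,
+        # so the first `text_seq_length` positions carry the text stream and the rest the video stream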
+ text_seq_length = encoder_hidden_states.shape[1]
+ encoder_hidden_states = hidden_states[:, :text_seq_length]
+ hidden_states = hidden_states[:, text_seq_length:]
+
+ # 3. Transformer blocks
+ for i, block in enumerate(self.transformer_blocks):
+ if self.training and self.gradient_checkpointing:
+
+ def create_custom_forward(module):
+ def custom_forward(*inputs):
+ return module(*inputs)
+
+ return custom_forward
+
+ ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
+ hidden_states, encoder_hidden_states = torch.utils.checkpoint.checkpoint(
+ create_custom_forward(block),
+ hidden_states,
+ encoder_hidden_states,
+ emb,
+ image_rotary_emb,
+ **ckpt_kwargs,
+ )
+ else:
+ hidden_states, encoder_hidden_states = block(
+ hidden_states=hidden_states,
+ encoder_hidden_states=encoder_hidden_states,
+ temb=emb,
+ image_rotary_emb=image_rotary_emb,
+ )
+
+ if not self.config.use_rotary_positional_embeddings:
+ # CogVideoX-2B
+ hidden_states = self.norm_final(hidden_states)
+ else:
+ # CogVideoX-5B
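+            # with rotary positional embeddings, the final LayerNorm is applied over the joint
+            # text + video sequence and the text tokens are dropped afterwards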
+ hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
+ hidden_states = self.norm_final(hidden_states)
+ hidden_states = hidden_states[:, text_seq_length:]
+
+ # 4. Final block
+ hidden_states = self.norm_out(hidden_states, temb=emb)
+ hidden_states = self.proj_out(hidden_states)
+
+ # 5. Unpatchify
+ # Note: we use `-1` instead of `channels`:
+        #   - Using `channels` would be fine for CogVideoX-2b and CogVideoX-5b (the number of input channels equals the number of output channels)
+        #   - However, CogVideoX-5b-I2V also takes concatenated input image latents (its number of input channels is twice the number of output channels)
+ p = self.config.patch_size
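+        # unpatchify: [B, F * H/p * W/p, C * p * p] -> [B, F, H/p, W/p, C, p, p] -> [B, F, C, H, W]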
+ output = hidden_states.reshape(batch_size, num_frames, height // p, width // p, -1, p, p)
+ output = output.permute(0, 1, 4, 2, 5, 3, 6).flatten(5, 6).flatten(3, 4)
+
+ if USE_PEFT_BACKEND:
+ # remove `lora_scale` from each PEFT layer
+ unscale_lora_layers(self, lora_scale)
+
+ if not return_dict:
+ return (output,)
+ return Transformer2DModelOutput(sample=output)
diff --git a/MindIE/MindIE-Torch/built-in/foundation/CogVideoX-5b/cogvideox_5b/pipelines/__init__.py b/MindIE/MindIE-Torch/built-in/foundation/CogVideoX-5b/cogvideox_5b/pipelines/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..1032118c1ea954fc485d79866e9d821c6352fba0
--- /dev/null
+++ b/MindIE/MindIE-Torch/built-in/foundation/CogVideoX-5b/cogvideox_5b/pipelines/__init__.py
@@ -0,0 +1 @@
+from .pipeline_cogvideox import CogVideoXPipeline
diff --git a/MindIE/MindIE-Torch/built-in/foundation/CogVideoX-5b/cogvideox_5b/pipelines/pipeline_cogvideox.py b/MindIE/MindIE-Torch/built-in/foundation/CogVideoX-5b/cogvideox_5b/pipelines/pipeline_cogvideox.py
new file mode 100644
index 0000000000000000000000000000000000000000..224e9d9b6a7f6350bbe13f0d3a7b8d6aa1098cd0
--- /dev/null
+++ b/MindIE/MindIE-Torch/built-in/foundation/CogVideoX-5b/cogvideox_5b/pipelines/pipeline_cogvideox.py
@@ -0,0 +1,760 @@
+# Copyright 2024 The CogVideoX team, Tsinghua University & ZhipuAI and The HuggingFace Team.
+# All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import inspect
+import math
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+import torch
+from transformers import T5EncoderModel, T5Tokenizer
+from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback
+from diffusers.loaders import CogVideoXLoraLoaderMixin
+from diffusers.models import AutoencoderKLCogVideoX
+from diffusers.pipelines.pipeline_utils import DiffusionPipeline
+from diffusers.schedulers import CogVideoXDDIMScheduler, CogVideoXDPMScheduler
+from diffusers.utils import logging, replace_example_docstring
+from diffusers.utils.torch_utils import randn_tensor
+from diffusers.video_processor import VideoProcessor
+from ..models import CogVideoXTransformer3DModel
+from ..models.embeddings import get_3d_rotary_pos_embed
+from .pipeline_output import CogVideoXPipelineOutput
+from ..utils.parallel_state import get_world_size, get_rank, all_gather
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+EXAMPLE_DOC_STRING = """
+ Examples:
+ ```python
+ >>> import torch
+ >>> from diffusers import CogVideoXPipeline
+ >>> from diffusers.utils import export_to_video
+
+ >>> # Models: "THUDM/CogVideoX-2b" or "THUDM/CogVideoX-5b"
+ >>> pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-2b", torch_dtype=torch.float16).to("cuda")
+ >>> prompt = (
+ ... "A panda, dressed in a small, red jacket and a tiny hat, sits on a wooden stool in a serene bamboo forest. "
+ ... "The panda's fluffy paws strum a miniature acoustic guitar, producing soft, melodic tunes. Nearby, a few other "
+ ... "pandas gather, watching curiously and some clapping in rhythm. Sunlight filters through the tall bamboo, "
+ ... "casting a gentle glow on the scene. The panda's face is expressive, showing concentration and joy as it plays. "
+ ... "The background includes a small, flowing stream and vibrant green foliage, enhancing the peaceful and magical "
+ ... "atmosphere of this unique musical performance."
+ ... )
+ >>> video = pipe(prompt=prompt, guidance_scale=6, num_inference_steps=50).frames[0]
+ >>> export_to_video(video, "output.mp4", fps=8)
+ ```
+"""
+
+# Similar to diffusers.pipelines.hunyuandit.pipeline_hunyuandit.get_resize_crop_region_for_grid
+def get_resize_crop_region_for_grid(src, tgt_width, tgt_height):
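+    """Resize the `src` (height, width) grid to fit inside a `tgt_width` x `tgt_height` grid while
+    preserving its aspect ratio, center it, and return the (top, left) and (bottom, right) corners
+    of the resulting region. Used to build crop coordinates for the rotary position embeddings."""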
+ tw = tgt_width
+ th = tgt_height
+ h, w = src
+ r = h / w
+ if r > (th / tw):
+ resize_height = th
+ resize_width = int(round(th / h * w))
+ else:
+ resize_width = tw
+ resize_height = int(round(tw / w * h))
+
+ crop_top = int(round((th - resize_height) / 2.0))
+ crop_left = int(round((tw - resize_width) / 2.0))
+
+ return (crop_top, crop_left), (crop_top + resize_height, crop_left + resize_width)
+
+
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
+def retrieve_timesteps(
+ scheduler,
+ num_inference_steps: Optional[int] = None,
+ device: Optional[Union[str, torch.device]] = None,
+ timesteps: Optional[List[int]] = None,
+ sigmas: Optional[List[float]] = None,
+ **kwargs,
+):
+ r"""
+ Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
+ custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
+
+ Args:
+ scheduler (`SchedulerMixin`):
+ The scheduler to get timesteps from.
+ num_inference_steps (`int`):
+ The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
+ must be `None`.
+ device (`str` or `torch.device`, *optional*):
+ The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
+ timesteps (`List[int]`, *optional*):
+ Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
+ `num_inference_steps` and `sigmas` must be `None`.
+ sigmas (`List[float]`, *optional*):
+ Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
+ `num_inference_steps` and `timesteps` must be `None`.
+
+ Returns:
+ `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
+ second element is the number of inference steps.
+ """
+ if timesteps is not None and sigmas is not None:
+ raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
+ if timesteps is not None:
+ accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
+ if not accepts_timesteps:
+ raise ValueError(
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+ f" timestep schedules. Please check whether you are using the correct scheduler."
+ )
+ scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ num_inference_steps = len(timesteps)
+ elif sigmas is not None:
+ accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
+ if not accept_sigmas:
+ raise ValueError(
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+ f" sigmas schedules. Please check whether you are using the correct scheduler."
+ )
+ scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ num_inference_steps = len(timesteps)
+ else:
+ scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ return timesteps, num_inference_steps
+
+
+class CogVideoXPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin):
+ r"""
+ Pipeline for text-to-video generation using CogVideoX.
+
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
+ library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
+
+ Args:
+ vae ([`AutoencoderKL`]):
+ Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations.
+ text_encoder ([`T5EncoderModel`]):
+ Frozen text-encoder. CogVideoX uses
+ [T5](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5EncoderModel); specifically the
+ [t5-v1_1-xxl](https://huggingface.co/PixArt-alpha/PixArt-alpha/tree/main/t5-v1_1-xxl) variant.
+ tokenizer (`T5Tokenizer`):
+ Tokenizer of class
+ [T5Tokenizer](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5Tokenizer).
+ transformer ([`CogVideoXTransformer3DModel`]):
+ A text conditioned `CogVideoXTransformer3DModel` to denoise the encoded video latents.
+ scheduler ([`SchedulerMixin`]):
+ A scheduler to be used in combination with `transformer` to denoise the encoded video latents.
+ """
+
+ _optional_components = []
+ model_cpu_offload_seq = "text_encoder->transformer->vae"
+
+ _callback_tensor_inputs = [
+ "latents",
+ "prompt_embeds",
+ "negative_prompt_embeds",
+ ]
+
+ def __init__(
+ self,
+ tokenizer: T5Tokenizer,
+ text_encoder: T5EncoderModel,
+ vae: AutoencoderKLCogVideoX,
+ transformer: CogVideoXTransformer3DModel,
+ scheduler: Union[CogVideoXDDIMScheduler, CogVideoXDPMScheduler],
+ ):
+ super().__init__()
+
+ self.register_modules(
+ tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, transformer=transformer, scheduler=scheduler
+ )
+ self.vae_scale_factor_spatial = (
+ 2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8
+ )
+ self.vae_scale_factor_temporal = (
+ self.vae.config.temporal_compression_ratio if hasattr(self, "vae") and self.vae is not None else 4
+ )
+ self.vae_scaling_factor_image = (
+ self.vae.config.scaling_factor if hasattr(self, "vae") and self.vae is not None else 0.7
+ )
+
+ self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial)
+
+ def _get_t5_prompt_embeds(
+ self,
+ prompt: Union[str, List[str]] = None,
+ num_videos_per_prompt: int = 1,
+ max_sequence_length: int = 226,
+ device: Optional[torch.device] = None,
+ dtype: Optional[torch.dtype] = None,
+ ):
+ device = device or self._execution_device
+ dtype = dtype or self.text_encoder.dtype
+
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+ batch_size = len(prompt)
+
+ text_inputs = self.tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=max_sequence_length,
+ truncation=True,
+ add_special_tokens=True,
+ return_tensors="pt",
+ )
+ text_input_ids = text_inputs.input_ids
+ untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
+
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
+ removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1])
+ logger.warning(
+ "The following part of your input was truncated because `max_sequence_length` is set to "
+ f" {max_sequence_length} tokens: {removed_text}"
+ )
+
+ prompt_embeds = self.text_encoder(text_input_ids.to(device))[0]
+ prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
+
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
+ _, seq_len, _ = prompt_embeds.shape
+ prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)
+ prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)
+
+ return prompt_embeds
+
+ def encode_prompt(
+ self,
+ prompt: Union[str, List[str]],
+ negative_prompt: Optional[Union[str, List[str]]] = None,
+ do_classifier_free_guidance: bool = True,
+ num_videos_per_prompt: int = 1,
+ prompt_embeds: Optional[torch.Tensor] = None,
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
+ max_sequence_length: int = 226,
+ device: Optional[torch.device] = None,
+ dtype: Optional[torch.dtype] = None,
+ ):
+ r"""
+ Encodes the prompt into text encoder hidden states.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ prompt to be encoded
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
+ less than `1`).
+ do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
+ Whether to use classifier free guidance or not.
+ num_videos_per_prompt (`int`, *optional*, defaults to 1):
+                Number of videos that should be generated per prompt.
+ prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+ argument.
+ device: (`torch.device`, *optional*):
+ torch device
+ dtype: (`torch.dtype`, *optional*):
+ torch dtype
+ """
+ device = device or self._execution_device
+
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+ if prompt is not None:
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ if prompt_embeds is None:
+ prompt_embeds = self._get_t5_prompt_embeds(
+ prompt=prompt,
+ num_videos_per_prompt=num_videos_per_prompt,
+ max_sequence_length=max_sequence_length,
+ device=device,
+ dtype=dtype,
+ )
+
+ if do_classifier_free_guidance and negative_prompt_embeds is None:
+ negative_prompt = negative_prompt or ""
+ negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
+
+ if prompt is not None and type(prompt) is not type(negative_prompt):
+ raise TypeError(
+ f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
+ f" {type(prompt)}."
+ )
+ elif batch_size != len(negative_prompt):
+ raise ValueError(
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
+ " the batch size of `prompt`."
+ )
+
+ negative_prompt_embeds = self._get_t5_prompt_embeds(
+ prompt=negative_prompt,
+ num_videos_per_prompt=num_videos_per_prompt,
+ max_sequence_length=max_sequence_length,
+ device=device,
+ dtype=dtype,
+ )
+
+ return prompt_embeds, negative_prompt_embeds
+
+ def prepare_latents(
+ self, batch_size, num_channels_latents, num_frames, height, width, dtype, device, generator, latents=None
+ ):
+ if isinstance(generator, list) and len(generator) != batch_size:
+ raise ValueError(
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+ )
+
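+        # latent shape: [batch, latent_frames, channels, latent_height, latent_width];
+        # num_frames pixel frames map to (num_frames - 1) // vae_scale_factor_temporal + 1 latent
+        # frames (e.g. 49 -> 13 with the default temporal compression of 4)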
+ shape = (
+ batch_size,
+ (num_frames - 1) // self.vae_scale_factor_temporal + 1,
+ num_channels_latents,
+ height // self.vae_scale_factor_spatial,
+ width // self.vae_scale_factor_spatial,
+ )
+
+ if latents is None:
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+ else:
+ latents = latents.to(device)
+
+ # scale the initial noise by the standard deviation required by the scheduler
+ latents = latents * self.scheduler.init_noise_sigma
+ return latents
+
+ def decode_latents(self, latents: torch.Tensor) -> torch.Tensor:
+ latents = latents.permute(0, 2, 1, 3, 4) # [batch_size, num_channels, num_frames, height, width]
+ latents = 1 / self.vae_scaling_factor_image * latents
+
+ frames = self.vae.decode(latents).sample
+ return frames
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
+ def prepare_extra_step_kwargs(self, generator, eta):
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
+ # and should be between [0, 1]
+
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ extra_step_kwargs = {}
+ if accepts_eta:
+ extra_step_kwargs["eta"] = eta
+
+ # check if the scheduler accepts generator
+ accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ if accepts_generator:
+ extra_step_kwargs["generator"] = generator
+ return extra_step_kwargs
+
+ # Copied from diffusers.pipelines.latte.pipeline_latte.LattePipeline.check_inputs
+ def check_inputs(
+ self,
+ prompt,
+ height,
+ width,
+ negative_prompt,
+ callback_on_step_end_tensor_inputs,
+ prompt_embeds=None,
+ negative_prompt_embeds=None,
+ ):
+ if height % 8 != 0 or width % 8 != 0:
+ raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
+
+ if callback_on_step_end_tensor_inputs is not None and not all(
+ k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
+ ):
+ raise ValueError(
+ f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
+ )
+ if prompt is not None and prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+ " only forward one of the two."
+ )
+ elif prompt is None and prompt_embeds is None:
+ raise ValueError(
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
+ )
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+
+ if prompt is not None and negative_prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `prompt`: {prompt} and `negative_prompt_embeds`:"
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
+ )
+
+ if negative_prompt is not None and negative_prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
+ )
+
+ if prompt_embeds is not None and negative_prompt_embeds is not None:
+ if prompt_embeds.shape != negative_prompt_embeds.shape:
+ raise ValueError(
+ "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
+ f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
+ f" {negative_prompt_embeds.shape}."
+ )
+
+ def fuse_qkv_projections(self) -> None:
+ r"""Enables fused QKV projections."""
+ self.fusing_transformer = True
+ self.transformer.fuse_qkv_projections()
+
+ def unfuse_qkv_projections(self) -> None:
+ r"""Disable QKV projection fusion if enabled."""
+ if not self.fusing_transformer:
+ logger.warning("The Transformer was not initially fused for QKV projections. Doing nothing.")
+ else:
+ self.transformer.unfuse_qkv_projections()
+ self.fusing_transformer = False
+
+ def _prepare_rotary_positional_embeddings(
+ self,
+ height: int,
+ width: int,
+ num_frames: int,
+ device: torch.device,
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ grid_height = height // (self.vae_scale_factor_spatial * self.transformer.config.patch_size)
+ grid_width = width // (self.vae_scale_factor_spatial * self.transformer.config.patch_size)
+ base_size_width = 720 // (self.vae_scale_factor_spatial * self.transformer.config.patch_size) # 720/8/2
+ base_size_height = 480 // (self.vae_scale_factor_spatial * self.transformer.config.patch_size) # 480/8/2
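+        # the current latent grid is resized to fit inside the 720x480-derived base grid and
+        # center-aligned; the resulting crop coordinates feed the 3D RoPE computation below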
+
+ grid_crops_coords = get_resize_crop_region_for_grid(
+ (grid_height, grid_width), base_size_width, base_size_height
+ )
+ freqs_cos, freqs_sin = get_3d_rotary_pos_embed(
+ embed_dim=self.transformer.config.attention_head_dim,
+ crops_coords=grid_crops_coords,
+ grid_size=(grid_height, grid_width),
+ temporal_size=num_frames,
+ )
+
+ freqs_cos = freqs_cos.to(device=device)
+ freqs_sin = freqs_sin.to(device=device)
+
+ return freqs_cos, freqs_sin
+
+ @property
+ def guidance_scale(self):
+ return self._guidance_scale
+
+ @property
+ def num_timesteps(self):
+ return self._num_timesteps
+
+ @property
+ def attention_kwargs(self):
+ return self._attention_kwargs
+
+ @property
+ def interrupt(self):
+ return self._interrupt
+
+ @torch.no_grad()
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
+ def __call__(
+ self,
+ prompt: Optional[Union[str, List[str]]] = None,
+ negative_prompt: Optional[Union[str, List[str]]] = None,
+ height: int = 480,
+ width: int = 720,
+ num_frames: int = 49,
+ num_inference_steps: int = 50,
+ timesteps: Optional[List[int]] = None,
+ guidance_scale: float = 6,
+ use_dynamic_cfg: bool = False,
+ num_videos_per_prompt: int = 1,
+ eta: float = 0.0,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.FloatTensor] = None,
+ prompt_embeds: Optional[torch.FloatTensor] = None,
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+ output_type: str = "pil",
+ return_dict: bool = True,
+ attention_kwargs: Optional[Dict[str, Any]] = None,
+ callback_on_step_end: Optional[
+ Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
+ ] = None,
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
+ max_sequence_length: int = 226,
+ ) -> Union[CogVideoXPipelineOutput, Tuple]:
+ """
+ Function invoked when calling the pipeline for generation.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
+ instead.
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
+ less than `1`).
+ height (`int`, *optional*, defaults to self.transformer.config.sample_height * self.vae_scale_factor_spatial):
+ The height in pixels of the generated image. This is set to 480 by default for the best results.
+            width (`int`, *optional*, defaults to self.transformer.config.sample_width * self.vae_scale_factor_spatial):
+ The width in pixels of the generated image. This is set to 720 by default for the best results.
+            num_frames (`int`, defaults to `49`):
+ Number of frames to generate. Must be divisible by self.vae_scale_factor_temporal. Generated video will
+ contain 1 extra frame because CogVideoX is conditioned with (num_seconds * fps + 1) frames where
+ num_seconds is 6 and fps is 8. However, since videos can be saved at any fps, the only condition that
+ needs to be satisfied is that of divisibility mentioned above.
+ num_inference_steps (`int`, *optional*, defaults to 50):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ timesteps (`List[int]`, *optional*):
+ Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
+ in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
+ passed will be used. Must be in descending order.
+            guidance_scale (`float`, *optional*, defaults to 6):
+                Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+                `guidance_scale` is defined as `w` of equation 2. of [Imagen
+                Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+                1`. A higher guidance scale encourages the model to generate images that are closely linked to the
+                text `prompt`, usually at the expense of lower image quality.
+ num_videos_per_prompt (`int`, *optional*, defaults to 1):
+ The number of videos to generate per prompt.
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
+ to make generation deterministic.
+ latents (`torch.FloatTensor`, *optional*):
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+                tensor will be generated by sampling using the supplied random `generator`.
+ prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+ argument.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+                The output format of the generated image. Choose between
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] instead
+ of a plain tuple.
+ attention_kwargs (`dict`, *optional*):
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
+ `self.processor` in
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
+ callback_on_step_end (`Callable`, *optional*):
+                A function that is called at the end of each denoising step during inference. The function is called
+ with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
+ callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
+ `callback_on_step_end_tensor_inputs`.
+ callback_on_step_end_tensor_inputs (`List`, *optional*):
+ The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
+ will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
+ `._callback_tensor_inputs` attribute of your pipeline class.
+ max_sequence_length (`int`, defaults to `226`):
+                Maximum sequence length in the encoded prompt. Must be consistent with
+                `self.transformer.config.max_text_seq_length`, otherwise it may lead to poor results.
+
+ Examples:
+
+ Returns:
+ [`~pipelines.cogvideo.pipeline_cogvideox.CogVideoXPipelineOutput`] or `tuple`:
+ [`~pipelines.cogvideo.pipeline_cogvideox.CogVideoXPipelineOutput`] if `return_dict` is True, otherwise a
+ `tuple`. When returning a tuple, the first element is a list with the generated images.
+ """
+
+ if num_frames > 49:
+ raise ValueError(
+ "The number of frames must be less than 49 for now due to static positional embeddings. This will be updated in the future to remove this limitation."
+ )
+
+ if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
+ callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
+
+ num_videos_per_prompt = 1
+
+ # 1. Check inputs. Raise error if not correct
+ self.check_inputs(
+ prompt,
+ height,
+ width,
+ negative_prompt,
+ callback_on_step_end_tensor_inputs,
+ prompt_embeds,
+ negative_prompt_embeds,
+ )
+ self._guidance_scale = guidance_scale
+ self._attention_kwargs = attention_kwargs
+ self._interrupt = False
+
+ # 2. Default call parameters
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ device = self._execution_device
+
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # corresponds to doing no classifier free guidance.
+ do_classifier_free_guidance = guidance_scale > 1.0
+
+ # 3. Encode input prompt
+ prompt_embeds, negative_prompt_embeds = self.encode_prompt(
+ prompt,
+ negative_prompt,
+ do_classifier_free_guidance,
+ num_videos_per_prompt=num_videos_per_prompt,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ max_sequence_length=max_sequence_length,
+ device=device,
+ )
+ if do_classifier_free_guidance:
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
+
+ # 4. Prepare timesteps
+ timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
+ self._num_timesteps = len(timesteps)
+
+ # 5. Prepare latents.
+ latent_channels = self.transformer.config.in_channels
+ latents = self.prepare_latents(
+ batch_size * num_videos_per_prompt,
+ latent_channels,
+ num_frames,
+ height,
+ width,
+ prompt_embeds.dtype,
+ device,
+ generator,
+ latents,
+ )
+
+ # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
+
+ # 7. Create rotary embeds if required
+ image_rotary_emb = (
+ self._prepare_rotary_positional_embeddings(height, width, latents.size(1), device)
+ if self.transformer.config.use_rotary_positional_embeddings
+ else None
+ )
+
+ # 8. Denoising loop
+ num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
+
+ # p_t = self.transformer.config.patch_size_t or 1
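+ # Shard the prompt embeddings across ranks for sequence parallelism (no-op on a single device).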
+ latents, prompt_embeds, image_rotary_emb = self._init_sync_pipeline(
+ latents, prompt_embeds, image_rotary_emb,
+ latents.size(1)
+ )
+
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
+ # for DPM-solver++
+ old_pred_original_sample = None
+ for i, t in enumerate(timesteps):
+ if self.interrupt:
+ continue
+
+ latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
+
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
+ timestep = t.expand(latent_model_input.shape[0])
+
+ # predict noise model_output
+ noise_pred = self.transformer(
+ hidden_states=latent_model_input,
+ encoder_hidden_states=prompt_embeds,
+ timestep=timestep,
+ image_rotary_emb=image_rotary_emb,
+ attention_kwargs=attention_kwargs,
+ return_dict=False,
+ )[0]
+
+ noise_pred = noise_pred.float()
+
+ # perform guidance
+ if use_dynamic_cfg:
+ self._guidance_scale = 1 + guidance_scale * (
+ (1 - math.cos(math.pi * ((num_inference_steps - t.item()) / num_inference_steps) ** 5.0)) / 2
+ )
+ if do_classifier_free_guidance:
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+ noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+ # compute the previous noisy sample x_t -> x_t-1
+ if not isinstance(self.scheduler, CogVideoXDPMScheduler):
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
+ else:
+ latents, old_pred_original_sample = self.scheduler.step(
+ noise_pred,
+ old_pred_original_sample,
+ t,
+ timesteps[i - 1] if i > 0 else None,
+ latents,
+ **extra_step_kwargs,
+ return_dict=False,
+ )
+ latents = latents.to(prompt_embeds.dtype)
+
+ # call the callback, if provided
+ if callback_on_step_end is not None:
+ callback_kwargs = {}
+ for k in callback_on_step_end_tensor_inputs:
+ callback_kwargs[k] = locals()[k]
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
+
+ latents = callback_outputs.pop("latents", latents)
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
+ negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
+
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+ progress_bar.update()
+
+ if not output_type == "latent":
+ video = self.decode_latents(latents.half())
+ video = self.video_processor.postprocess_video(video=video, output_type=output_type)
+ else:
+ video = latents
+
+ # Offload all models
+ self.maybe_free_model_hooks()
+
+ if not return_dict:
+ return (video,)
+
+ return CogVideoXPipelineOutput(frames=video)
+
+ def _init_sync_pipeline(
+ self,
+ latents: torch.Tensor,
+ prompt_embeds: torch.Tensor,
+ image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+ latents_frames: Optional[int] = None,
+ ):
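+ # Split the text embeddings along the sequence dimension so each rank keeps only its own chunk;
+ # the split is skipped when the text sequence length is not evenly divisible by the world size.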
+ if prompt_embeds.shape[-2] % get_world_size() == 0:
+ prompt_embeds = torch.chunk(prompt_embeds, get_world_size(), dim=-2)[get_rank()]
+ return latents, prompt_embeds, image_rotary_emb
\ No newline at end of file
diff --git a/MindIE/MindIE-Torch/built-in/foundation/CogVideoX-5b/cogvideox_5b/pipelines/pipeline_output.py b/MindIE/MindIE-Torch/built-in/foundation/CogVideoX-5b/cogvideox_5b/pipelines/pipeline_output.py
new file mode 100644
index 0000000000000000000000000000000000000000..3de030dd6928db49ab0bc4d11868a93ac98dea50
--- /dev/null
+++ b/MindIE/MindIE-Torch/built-in/foundation/CogVideoX-5b/cogvideox_5b/pipelines/pipeline_output.py
@@ -0,0 +1,20 @@
+from dataclasses import dataclass
+
+import torch
+
+from diffusers.utils import BaseOutput
+
+
+@dataclass
+class CogVideoXPipelineOutput(BaseOutput):
+ r"""
+ Output class for CogVideo pipelines.
+
+ Args:
+ frames (`torch.Tensor`, `np.ndarray`, or List[List[PIL.Image.Image]]):
+ List of video outputs. It can be a nested list of length `batch_size`, with each sub-list containing
+ denoised PIL image sequences of length `num_frames`. It can also be a NumPy array or Torch tensor of shape
+ `(batch_size, num_frames, channels, height, width)`.
+ """
+
+ frames: torch.Tensor
diff --git a/MindIE/MindIE-Torch/built-in/foundation/CogVideoX-5b/cogvideox_5b/utils/__init__.py b/MindIE/MindIE-Torch/built-in/foundation/CogVideoX-5b/cogvideox_5b/utils/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e3931ccb6267c2a21934d8d07fee0d7ed15a92d5
--- /dev/null
+++ b/MindIE/MindIE-Torch/built-in/foundation/CogVideoX-5b/cogvideox_5b/utils/__init__.py
@@ -0,0 +1 @@
+from .parallel_state import get_rank, get_world_size, all_gather
diff --git a/MindIE/MindIE-Torch/built-in/foundation/CogVideoX-5b/cogvideox_5b/utils/parallel_mgr.py b/MindIE/MindIE-Torch/built-in/foundation/CogVideoX-5b/cogvideox_5b/utils/parallel_mgr.py
new file mode 100644
index 0000000000000000000000000000000000000000..b0bd7d2469155ec173038069ea68282ed437ffbb
--- /dev/null
+++ b/MindIE/MindIE-Torch/built-in/foundation/CogVideoX-5b/cogvideox_5b/utils/parallel_mgr.py
@@ -0,0 +1,36 @@
+import os
+import torch
+import torch_npu
+import torch.distributed as dist
+from torch_npu._C._distributed_c10d import ProcessGroupHCCL
+
+
+class ParallelManager:
+ def __init__(self):
+ local_rank = int(os.environ.get("LOCAL_RANK", "0"))
+ world_size = int(os.environ.get("WORLD_SIZE", "1"))
+ self.rank = local_rank
+ self.world_size = world_size
+ if self.world_size > 1:
+ self.init_group()
+
+ def init_group(self):
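+ # Bind this process to its NPU and create the HCCL process group for multi-device runs.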
+ device = torch.device(f"npu:{self.rank}")
+ torch.npu.set_device(device)
+
+ backend = "hccl"
+ options = ProcessGroupHCCL.Options()
+ print("ProcessGroupHCCL options have been set")
+ if not torch.distributed.is_initialized():
+ # Call the init process.
+ torch.distributed.init_process_group(
+ backend=backend,
+ world_size=self.world_size,
+ rank=self.rank,
+ pg_options=options,
+ )
+ print(f"rank {self.rank}: init_process_group completed, initialized={torch.distributed.is_initialized()}")
+ else:
+ print("torch.distributed is already initialized.")
+
diff --git a/MindIE/MindIE-Torch/built-in/foundation/CogVideoX-5b/cogvideox_5b/utils/parallel_state.py b/MindIE/MindIE-Torch/built-in/foundation/CogVideoX-5b/cogvideox_5b/utils/parallel_state.py
new file mode 100644
index 0000000000000000000000000000000000000000..0fbae757326804391523a185c0f06f60805bf60c
--- /dev/null
+++ b/MindIE/MindIE-Torch/built-in/foundation/CogVideoX-5b/cogvideox_5b/utils/parallel_state.py
@@ -0,0 +1,53 @@
+import torch
+from typing import List, Union
+from .parallel_mgr import ParallelManager
+
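+# Module-level singleton: reads LOCAL_RANK/WORLD_SIZE from the environment at import time and,
+# when world_size > 1, initializes the HCCL process group.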
+mgr = ParallelManager()
+
+def get_world_size():
+ return mgr.world_size
+
+def get_rank():
+ return mgr.rank
+
+def all_gather(input_: torch.Tensor, dim: int = 0, separate_tensors: bool = False
+ ) -> Union[torch.Tensor, List[torch.Tensor]]:
+ world_size = get_world_size()
+ # Bypass the function if we are using only one device.
+ if world_size == 1:
+ return input_
+ assert (
+ -input_.dim() <= dim < input_.dim()
+ ), f"Invalid dim ({dim}) for input tensor with shape {input_.size()}"
+ if dim < 0:
+ # Convert negative dim to positive.
+ dim += input_.dim()
+ # Allocate output tensor.
+ input_size = list(input_.size())
+ input_size[0] *= world_size
+ output_tensor = torch.empty(
+ input_size, dtype=input_.dtype, device=input_.device
+ )
+ # All-gather.
+ torch.distributed.all_gather_into_tensor(
+ output_tensor, input_
+ )
+ if dim != 0:
+ input_size[0] //= world_size
+ output_tensor = output_tensor.reshape([world_size, ] + input_size)
+ output_tensor = output_tensor.movedim(0, dim)
+
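+ # Return either one tensor concatenated along `dim`, or a list of per-rank tensors shaped like the input.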
+ if separate_tensors:
+ tensor_list = [
+ output_tensor.view(-1)
+ .narrow(0, input_.numel() * i, input_.numel())
+ .view_as(input_)
+ for i in range(world_size)
+ ]
+ return tensor_list
+ else:
+ input_size = list(input_.size())
+ input_size[dim] = input_size[dim] * world_size
+ # Reshape
+ output_tensor = output_tensor.reshape(input_size)
+ return output_tensor
diff --git a/MindIE/MindIE-Torch/built-in/foundation/CogVideoX-5b/inference.py b/MindIE/MindIE-Torch/built-in/foundation/CogVideoX-5b/inference.py
new file mode 100644
index 0000000000000000000000000000000000000000..03ae7b1c7d820c504863b79cd4f82db569de15d5
--- /dev/null
+++ b/MindIE/MindIE-Torch/built-in/foundation/CogVideoX-5b/inference.py
@@ -0,0 +1,205 @@
+import os
+import argparse
+import time
+import torch
+import torch_npu
+import functools
+from typing import Literal, Optional, Tuple
+from cogvideox_5b import CogVideoXPipeline, CogVideoXTransformer3DModel, get_rank, get_world_size, all_gather
+from diffusers import CogVideoXDPMScheduler
+from diffusers.utils import export_to_video
+from torch_npu.contrib import transfer_to_npu
+
+
+def parallelize_transformer(pipe):
+ transformer = pipe.transformer
+ original_forward = transformer.forward
+
+ @functools.wraps(transformer.__class__.forward)
+ def new_forward(
+ self,
+ hidden_states: torch.Tensor,
+ encoder_hidden_states: Optional[torch.Tensor] = None,
+ timestep: torch.LongTensor = None,
+ timestep_cond: Optional[torch.Tensor] = None,
+ image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+ **kwargs,
+ ):
+ temporal_size = hidden_states.shape[1]
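+ # Sequence-parallel sharding of the latent (sizes are hard-coded for 720x480 generation): pad the
+ # 60 latent rows with 4 zero rows so the height splits evenly across ranks, then keep this rank's chunk.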
+ batch, frames, channels, h, w = hidden_states.shape
+ hidden_states = torch.cat([hidden_states, torch.zeros(batch, frames, channels, 4, w, device=hidden_states.device, dtype=hidden_states.dtype)], dim=-2)
+ hidden_states = torch.chunk(hidden_states, get_world_size(), dim=-2)[get_rank()]
+ if image_rotary_emb is not None:
+ freqs_cos, freqs_sin = image_rotary_emb
+
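+ # Apply the same pad-and-chunk to the rotary position embeddings: reshape to
+ # (latent_frames, height_patches, 45 width_patches, dim), pad 2 height patches (matching the
+ # 4 padded latent rows at patch size 2), then keep this rank's chunk.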
+ def get_rotary_emb_chunk(freqs):
+ dim_thw = freqs.shape[-1]
+ freqs = freqs.reshape(temporal_size, -1, dim_thw)
+ freqs = freqs.reshape(temporal_size, -1, 45, dim_thw)
+ freqs = torch.cat([freqs, torch.zeros(temporal_size, 2, 45, dim_thw, device=freqs.device, dtype=freqs.dtype)], dim=1)
+ freqs = freqs.reshape(temporal_size, -1, dim_thw)
+ freqs = torch.chunk(freqs, get_world_size(), dim=-2)[get_rank()]
+ freqs = freqs.reshape(-1, dim_thw)
+ return freqs
+
+ freqs_cos = get_rotary_emb_chunk(freqs_cos)
+ freqs_sin = get_rotary_emb_chunk(freqs_sin)
+ image_rotary_emb = (freqs_cos, freqs_sin)
+
+ output = original_forward(
+ hidden_states,
+ encoder_hidden_states,
+ timestep=timestep,
+ timestep_cond=timestep_cond,
+ image_rotary_emb=image_rotary_emb,
+ **kwargs,
+ )
+
+ return_dict = not isinstance(output, tuple)
+ sample = output[0]
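+ # Gather the height chunks from all ranks and drop the 4 padded rows before returning.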
+ sample = all_gather(sample, dim=-2)
+ sample = sample[:, :, :, :-4, :]
+ if return_dict:
+ return output.__class__(sample, *output[1:])
+ return (sample, *output[1:])
+
+ new_forward = new_forward.__get__(transformer)
+ transformer.forward = new_forward
+
+ original_patch_embed_forward = transformer.patch_embed.forward
+
+ @functools.wraps(transformer.patch_embed.__class__.forward)
+ def new_patch_embed(
+ self, text_embeds: torch.Tensor, image_embeds: torch.Tensor
+ ):
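+ # Gather the sharded text and image embeddings, run the original patch embedding on the full
+ # sequence (text tokens followed by image tokens), then re-shard both parts for this rank.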
+ text_embeds = all_gather(text_embeds.contiguous(), dim=-2)
+ image_embeds = all_gather(image_embeds.contiguous(), dim=-2)
+ batch, num_frames, channels, height, width = image_embeds.shape
+ text_len = text_embeds.shape[-2]
+ output = original_patch_embed_forward(text_embeds, image_embeds)
+ text_embeds = output[:, :text_len, :]
+ image_embeds = output[:, text_len:, :].reshape(batch, num_frames, -1, output.shape[-1])
+
+ text_embeds = torch.chunk(text_embeds, get_world_size(), dim=-2)[get_rank()]
+ image_embeds = torch.chunk(image_embeds, get_world_size(), dim=-2)[get_rank()]
+ image_embeds = image_embeds.reshape(batch, -1, image_embeds.shape[-1])
+ return torch.cat([text_embeds, image_embeds], dim=1)
+
+ new_patch_embed = new_patch_embed.__get__(transformer.patch_embed)
+ transformer.patch_embed.forward = new_patch_embed
+
+
+def generate_video(
+ prompt: str,
+ model_path: str,
+ lora_path: str = None,
+ lora_rank: int = 128,
+ num_frames: int = 81,
+ width: int = 1360,
+ height: int = 768,
+ output_path: str = "./output.mp4",
+ image_or_video_path: str = "",
+ num_inference_steps: int = 50,
+ guidance_scale: float = 6.0,
+ num_videos_per_prompt: int = 1,
+ dtype: torch.dtype = torch.bfloat16,
+ generate_type: Literal["t2v", "i2v", "v2v"] = "t2v",  # i2v: image to video, v2v: video to video
+ seed: int = 42,
+ fps: int = 8
+):
+ pipe = CogVideoXPipeline.from_pretrained(model_path, torch_dtype=dtype).to(f"npu:{get_rank()}")
+ transformer = CogVideoXTransformer3DModel.from_pretrained(os.path.join(model_path,'transformer'), torch_dtype=dtype).to(f"npu:{get_rank()}")
+ if lora_path:
+ pipe.load_lora_weights(lora_path, weight_name="pytorch_lora_weights.safetensors", adapter_name="test_1")
+ pipe.fuse_lora(lora_scale=1 / lora_rank)
+ pipe.scheduler = CogVideoXDPMScheduler.from_config(pipe.scheduler.config, timestep_spacing="trailing")
+ pipe.transformer = transformer
+ pipe.vae = pipe.vae.half()
+ pipe.vae.enable_slicing()
+ pipe.vae.enable_tiling()
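+ # Enable sequence parallelism over the transformer when launched with more than one rank.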
+ if get_world_size() > 1:
+ parallelize_transformer(pipe)
+
+ # warm up
+ video_generate = pipe(
+ height=height,
+ width=width,
+ prompt=prompt,
+ num_videos_per_prompt=num_videos_per_prompt,
+ num_inference_steps=1,
+ num_frames=num_frames,
+ use_dynamic_cfg=True,
+ guidance_scale=guidance_scale,
+ generator=torch.Generator().manual_seed(seed),
+ output_type="pil"
+ ).frames[0]
+
+ torch_npu.npu.synchronize()
+ start = time.time()
+ video_generate = pipe(
+ height=height,
+ width=width,
+ prompt=prompt,
+ num_videos_per_prompt=num_videos_per_prompt,
+ num_inference_steps=num_inference_steps,
+ num_frames=num_frames,
+ use_dynamic_cfg=True,
+ guidance_scale=guidance_scale,
+ generator=torch.Generator().manual_seed(seed),
+ output_type="pil"
+ ).frames[0]
+ torch_npu.npu.synchronize()
+ end = time.time()
+ print(f"Time taken for inference: {end - start} seconds")
+
+ export_to_video(video_generate, output_path, fps=fps)
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser(description="Generate a video from a text prompt using CogVideoX")
+ parser.add_argument("--prompt", type=str, required=True, help="The description of the video to be generated")
+ parser.add_argument(
+ "--image_or_video_path",
+ type=str,
+ default=None,
+ help="The path of the image to be used as the background of the video",
+ )
+ parser.add_argument(
+ "--model_path", type=str, default="/data/CogVideoX-5b", help="Path of the pre-trained model to use"
+ )
+ parser.add_argument("--lora_path", type=str, default=None, help="The path of the LoRA weights to be used")
+ parser.add_argument("--lora_rank", type=int, default=128, help="The rank of the LoRA weights")
+ parser.add_argument("--output_path", type=str, default="./output.mp4", help="The path where the generated video will be saved")
+ parser.add_argument("--guidance_scale", type=float, default=6.0, help="The scale for classifier-free guidance")
+ parser.add_argument("--num_inference_steps", type=int, default=50, help="Inference steps")
+ parser.add_argument("--num_frames", type=int, default=48, help="Number of frames to generate")
+ parser.add_argument("--width", type=int, default=720, help="Width of the generated video in pixels")
+ parser.add_argument("--height", type=int, default=480, help="Height of the generated video in pixels")
+ parser.add_argument("--fps", type=int, default=8, help="Frames per second of the exported video")
+ parser.add_argument("--num_videos_per_prompt", type=int, default=1, help="Number of videos to generate per prompt")
+ parser.add_argument("--generate_type", type=str, default="t2v", help="The type of video generation")
+ parser.add_argument("--dtype", type=str, default="bfloat16", help="The data type for computation")
+ parser.add_argument("--seed", type=int, default=42, help="The seed for reproducibility")
+
+ args = parser.parse_args()
+ dtype = torch.float16 if args.dtype == "float16" else torch.bfloat16
+ torch.npu.config.allow_internal_format = False
+ generate_video(
+ prompt=args.prompt,
+ model_path=args.model_path,
+ lora_path=args.lora_path,
+ lora_rank=args.lora_rank,
+ output_path=args.output_path,
+ num_frames=args.num_frames,
+ width=args.width,
+ height=args.height,
+ image_or_video_path=args.image_or_video_path,
+ num_inference_steps=args.num_inference_steps,
+ guidance_scale=args.guidance_scale,
+ num_videos_per_prompt=args.num_videos_per_prompt,
+ dtype=dtype,
+ generate_type=args.generate_type,
+ seed=args.seed,
+ fps=args.fps,
+ )
+
diff --git a/MindIE/MindIE-Torch/built-in/foundation/CogVideoX-5b/requirements.txt b/MindIE/MindIE-Torch/built-in/foundation/CogVideoX-5b/requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..30086550650da8c1797e8c1907781b9c91900703
--- /dev/null
+++ b/MindIE/MindIE-Torch/built-in/foundation/CogVideoX-5b/requirements.txt
@@ -0,0 +1,14 @@
+diffusers>=0.31.0
+accelerate>=1.1.1
+transformers>=4.46.2
+numpy==1.26.0
+torch>=2.5.0
+torchvision>=0.20.0
+sentencepiece>=0.2.0
+SwissArmyTransformer>=0.4.12
+gradio>=5.5.0
+imageio>=2.35.1
+imageio-ffmpeg>=0.5.1
+openai>=1.54.0
+moviepy>=1.0.3
+scikit-video>=1.1.11
\ No newline at end of file