From e008ef6b8cdf088fcbe041271a73ddf3f4993fc5 Mon Sep 17 00:00:00 2001
From: taoyuan-guo
Date: Mon, 24 Feb 2025 17:06:04 +0800
Subject: [PATCH 1/2] hunyuanvideo: initial version of the code
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 .../hyvideo/vae/mod_attention.py              | 107 ++++++++++++++++++
 1 file changed, 107 insertions(+)
 create mode 100644 MindIE/MindIE-Torch/built-in/foundation/hunyuan_video/hyvideo/vae/mod_attention.py

diff --git a/MindIE/MindIE-Torch/built-in/foundation/hunyuan_video/hyvideo/vae/mod_attention.py b/MindIE/MindIE-Torch/built-in/foundation/hunyuan_video/hyvideo/vae/mod_attention.py
new file mode 100644
index 0000000000..4a7b9c18f7
--- /dev/null
+++ b/MindIE/MindIE-Torch/built-in/foundation/hunyuan_video/hyvideo/vae/mod_attention.py
@@ -0,0 +1,107 @@
+from typing import Optional
+import torch
+import torch.nn.functional as F
+import math
+import torch_npu
+from diffusers.models.attention_processor import Attention
+MAX_TOKEN = 2147483647
+
+
+class AttnProcessor2_0:
+    r"""
+    Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
+    """
+
+    def __init__(self):
+        if not hasattr(F, "scaled_dot_product_attention"):
+            raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
+
+    def __call__(
+        self,
+        attn: Attention,
+        hidden_states: torch.Tensor,
+        encoder_hidden_states: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        temb: Optional[torch.Tensor] = None,
+        *args,
+        **kwargs,
+    ) -> torch.Tensor:
+
+        residual = hidden_states
+        if attn.spatial_norm is not None:
+            hidden_states = attn.spatial_norm(hidden_states, temb)
+
+        input_ndim = hidden_states.ndim
+
+        if input_ndim == 4:
+            batch_size, channel, height, width = hidden_states.shape
+            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
+
+        batch_size, sequence_length, _ = (
+            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
+        )
+
+        if attention_mask is not None:
+            attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
+            # scaled_dot_product_attention expects attention_mask shape to be
+            # (batch, heads, source_length, target_length)
+            attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
+
+        if attn.group_norm is not None:
+            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
+
+        query = attn.to_q(hidden_states)
+
+        if encoder_hidden_states is None:
+            encoder_hidden_states = hidden_states
+        elif attn.norm_cross:
+            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
+
+        key = attn.to_k(encoder_hidden_states)
+        value = attn.to_v(encoder_hidden_states)
+
+        inner_dim = key.shape[-1]
+        head_dim = inner_dim // attn.heads
+
+        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+        if attn.norm_q is not None:
+            query = attn.norm_q(query)
+        if attn.norm_k is not None:
+            key = attn.norm_k(key)
+
+        # the output of sdp = (batch, num_heads, seq_len, head_dim)
+        scale = 1.0 / math.sqrt(head_dim)
+        hidden_states = torch_npu.npu_fusion_attention(
+            query, key, value,
+            head_num=attn.heads,
+            input_layout="BNSD",
+            scale=scale,
+            pse=None,
+            atten_mask=attention_mask,
+            pre_tockens=MAX_TOKEN,
+            next_tockens=MAX_TOKEN,
+            keep_prob=1.0,
+            sync=False
+        )[0]
+
+        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
+        hidden_states = hidden_states.to(query.dtype)
+
+        # linear proj
+        hidden_states = attn.to_out[0](hidden_states)
+        # dropout
+        hidden_states = attn.to_out[1](hidden_states)
+
+        if input_ndim == 4:
+            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
+
+        if attn.residual_connection:
+            hidden_states = hidden_states + residual
+
+        hidden_states = hidden_states / attn.rescale_output_factor
+
+        return hidden_states
\ No newline at end of file
-- 
Gitee

From 255235bcf99ab54a01d5cb37550a7637d33571a3 Mon Sep 17 00:00:00 2001
From: taoyuan-guo
Date: Mon, 24 Feb 2025 18:37:50 +0800
Subject: [PATCH 2/2] hunyuanvideo: initial version of the code
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 .../foundation/hunyuan_video/hyvideo/vae/mod_attention.py      | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/MindIE/MindIE-Torch/built-in/foundation/hunyuan_video/hyvideo/vae/mod_attention.py b/MindIE/MindIE-Torch/built-in/foundation/hunyuan_video/hyvideo/vae/mod_attention.py
index 4a7b9c18f7..b53e6c8a11 100644
--- a/MindIE/MindIE-Torch/built-in/foundation/hunyuan_video/hyvideo/vae/mod_attention.py
+++ b/MindIE/MindIE-Torch/built-in/foundation/hunyuan_video/hyvideo/vae/mod_attention.py
@@ -1,7 +1,7 @@
 from typing import Optional
+import math
 import torch
 import torch.nn.functional as F
-import math
 import torch_npu
 from diffusers.models.attention_processor import Attention
 MAX_TOKEN = 2147483647
-- 
Gitee
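
For reference, a minimal usage sketch of the processor introduced in PATCH 1/2. It shows how the class can be attached to a diffusers Attention block so that scaled dot-product attention is routed through torch_npu.npu_fusion_attention. The dimensions and constructor arguments below are illustrative assumptions (the real VAE blocks define their own shapes), and running it requires an Ascend NPU with torch_npu installed:

    # Usage sketch -- illustrative only, not part of the patches above.
    import torch
    from diffusers.models.attention_processor import Attention

    from mod_attention import AttnProcessor2_0  # the module created by PATCH 1/2

    # Hypothetical dimensions; heads * dim_head matches query_dim here so
    # to_out projects back to the input width.
    attn = Attention(query_dim=512, heads=8, dim_head=64, residual_connection=True)
    attn.set_processor(AttnProcessor2_0())  # replace the default SDPA processor

    # (batch, sequence, channels); move the module and inputs to the NPU in practice.
    hidden_states = torch.randn(1, 256, 512)
    out = attn(hidden_states)  # same shape as the input: (1, 256, 512)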