From 44418929a67c32588b37f58c5bae15604f6450c6 Mon Sep 17 00:00:00 2001 From: taoyuan-guo Date: Fri, 21 Feb 2025 12:57:44 +0800 Subject: [PATCH 1/5] hunyuanvideo --- .../hyvideo/vae/autoencoder_kl_causal_3d.py | 603 ++++++++++++++++++ 1 file changed, 603 insertions(+) create mode 100644 MindIE/MindIE-Torch/built-in/foundation/hunyuan_video/hyvideo/vae/autoencoder_kl_causal_3d.py diff --git a/MindIE/MindIE-Torch/built-in/foundation/hunyuan_video/hyvideo/vae/autoencoder_kl_causal_3d.py b/MindIE/MindIE-Torch/built-in/foundation/hunyuan_video/hyvideo/vae/autoencoder_kl_causal_3d.py new file mode 100644 index 0000000000..e635884e2f --- /dev/null +++ b/MindIE/MindIE-Torch/built-in/foundation/hunyuan_video/hyvideo/vae/autoencoder_kl_causal_3d.py @@ -0,0 +1,603 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +# +# Modified from diffusers==0.29.2 +# +# ============================================================================== +from typing import Dict, Optional, Tuple, Union +from dataclasses import dataclass + +import torch +import torch.nn as nn + +from diffusers.configuration_utils import ConfigMixin, register_to_config + +try: + # This diffusers is modified and packed in the mirror. + from diffusers.loaders import FromOriginalVAEMixin +except ImportError: + # Use this to be compatible with the original diffusers. + from diffusers.loaders.single_file_model import FromOriginalModelMixin as FromOriginalVAEMixin +from diffusers.utils.accelerate_utils import apply_forward_hook +from diffusers.models.attention_processor import ( + ADDED_KV_ATTENTION_PROCESSORS, + CROSS_ATTENTION_PROCESSORS, + Attention, + AttentionProcessor, + AttnAddedKVProcessor, + AttnProcessor, +) +from diffusers.models.modeling_outputs import AutoencoderKLOutput +from diffusers.models.modeling_utils import ModelMixin +from .vae import DecoderCausal3D, BaseOutput, DecoderOutput, DiagonalGaussianDistribution, EncoderCausal3D + + +@dataclass +class DecoderOutput2(BaseOutput): + sample: torch.FloatTensor + posterior: Optional[DiagonalGaussianDistribution] = None + + +class AutoencoderKLCausal3D(ModelMixin, ConfigMixin, FromOriginalVAEMixin): + r""" + A VAE model with KL loss for encoding images/videos into latents and decoding latent representations into images/videos. + + This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented + for all models (such as downloading or saving). + """ + + _supports_gradient_checkpointing = True + + @register_to_config + def __init__( + self, + in_channels: int = 3, + out_channels: int = 3, + down_block_types: Tuple[str] = ("DownEncoderBlockCausal3D",), + up_block_types: Tuple[str] = ("UpDecoderBlockCausal3D",), + block_out_channels: Tuple[int] = (64,), + layers_per_block: int = 1, + act_fn: str = "silu", + latent_channels: int = 4, + norm_num_groups: int = 32, + sample_size: int = 32, + sample_tsize: int = 64, + scaling_factor: float = 0.18215, + force_upcast: float = True, + spatial_compression_ratio: int = 8, + time_compression_ratio: int = 4, + mid_block_add_attention: bool = True, + ): + super().__init__() + + self.time_compression_ratio = time_compression_ratio + + self.encoder = EncoderCausal3D( + in_channels=in_channels, + out_channels=latent_channels, + down_block_types=down_block_types, + block_out_channels=block_out_channels, + layers_per_block=layers_per_block, + act_fn=act_fn, + norm_num_groups=norm_num_groups, + double_z=True, + time_compression_ratio=time_compression_ratio, + spatial_compression_ratio=spatial_compression_ratio, + mid_block_add_attention=mid_block_add_attention, + ) + + self.decoder = DecoderCausal3D( + in_channels=latent_channels, + out_channels=out_channels, + up_block_types=up_block_types, + block_out_channels=block_out_channels, + layers_per_block=layers_per_block, + norm_num_groups=norm_num_groups, + act_fn=act_fn, + time_compression_ratio=time_compression_ratio, + spatial_compression_ratio=spatial_compression_ratio, + mid_block_add_attention=mid_block_add_attention, + ) + + self.quant_conv = nn.Conv3d(2 * latent_channels, 2 * latent_channels, kernel_size=1) + self.post_quant_conv = nn.Conv3d(latent_channels, latent_channels, kernel_size=1) + + self.use_slicing = False + self.use_spatial_tiling = False + self.use_temporal_tiling = False + + # only relevant if vae tiling is enabled + self.tile_sample_min_tsize = sample_tsize + self.tile_latent_min_tsize = sample_tsize // time_compression_ratio + + self.tile_sample_min_size = self.config.sample_size + sample_size = ( + self.config.sample_size[0] + if isinstance(self.config.sample_size, (list, tuple)) + else self.config.sample_size + ) + self.tile_latent_min_size = int(sample_size / (2 ** (len(self.config.block_out_channels) - 1))) + self.tile_overlap_factor = 0.25 + + def _set_gradient_checkpointing(self, module, value=False): + if isinstance(module, (EncoderCausal3D, DecoderCausal3D)): + module.gradient_checkpointing = value + + def enable_temporal_tiling(self, use_tiling: bool = True): + self.use_temporal_tiling = use_tiling + + def disable_temporal_tiling(self): + self.enable_temporal_tiling(False) + + def enable_spatial_tiling(self, use_tiling: bool = True): + self.use_spatial_tiling = use_tiling + + def disable_spatial_tiling(self): + self.enable_spatial_tiling(False) + + def enable_tiling(self, use_tiling: bool = True): + r""" + Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to + compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow + processing larger videos. + """ + self.enable_spatial_tiling(use_tiling) + self.enable_temporal_tiling(use_tiling) + + def disable_tiling(self): + r""" + Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing + decoding in one step. + """ + self.disable_spatial_tiling() + self.disable_temporal_tiling() + + def enable_slicing(self): + r""" + Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to + compute decoding in several steps. This is useful to save some memory and allow larger batch sizes. + """ + self.use_slicing = True + + def disable_slicing(self): + r""" + Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing + decoding in one step. + """ + self.use_slicing = False + + @property + # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.attn_processors + def attn_processors(self) -> Dict[str, AttentionProcessor]: + r""" + Returns: + `dict` of attention processors: A dictionary containing all attention processors used in the model with + indexed by its weight name. + """ + # set recursively + processors = {} + + def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]): + if hasattr(module, "get_processor"): + processors[f"{name}.processor"] = module.get_processor(return_deprecated_lora=True) + + for sub_name, child in module.named_children(): + fn_recursive_add_processors(f"{name}.{sub_name}", child, processors) + + return processors + + for name, module in self.named_children(): + fn_recursive_add_processors(name, module, processors) + + return processors + + # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor + def set_attn_processor( + self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]], _remove_lora=False + ): + r""" + Sets the attention processor to use to compute attention. + + Parameters: + processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`): + The instantiated processor class or a dictionary of processor classes that will be set as the processor + for **all** `Attention` layers. + + If `processor` is a dict, the key needs to define the path to the corresponding cross attention + processor. This is strongly recommended when setting trainable attention processors. + + """ + count = len(self.attn_processors.keys()) + + if isinstance(processor, dict) and len(processor) != count: + raise ValueError( + f"A dict of processors was passed, but the number of processors {len(processor)} does not match the" + f" number of attention layers: {count}. Please make sure to pass {count} processor classes." + ) + + def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor): + if hasattr(module, "set_processor"): + if not isinstance(processor, dict): + module.set_processor(processor, _remove_lora=_remove_lora) + else: + module.set_processor(processor.pop(f"{name}.processor"), _remove_lora=_remove_lora) + + for sub_name, child in module.named_children(): + fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor) + + for name, module in self.named_children(): + fn_recursive_attn_processor(name, module, processor) + + # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor + def set_default_attn_processor(self): + """ + Disables custom attention processors and sets the default attention implementation. + """ + if all(proc.__class__ in ADDED_KV_ATTENTION_PROCESSORS for proc in self.attn_processors.values()): + processor = AttnAddedKVProcessor() + elif all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()): + processor = AttnProcessor() + else: + raise ValueError( + f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}" + ) + + self.set_attn_processor(processor, _remove_lora=True) + + @apply_forward_hook + def encode( + self, x: torch.FloatTensor, return_dict: bool = True + ) -> Union[AutoencoderKLOutput, Tuple[DiagonalGaussianDistribution]]: + """ + Encode a batch of images/videos into latents. + + Args: + x (`torch.FloatTensor`): Input batch of images/videos. + return_dict (`bool`, *optional*, defaults to `True`): + Whether to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple. + + Returns: + The latent representations of the encoded images/videos. If `return_dict` is True, a + [`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain `tuple` is returned. + """ + assert len(x.shape) == 5, "The input tensor should have 5 dimensions." + + if self.use_temporal_tiling and x.shape[2] > self.tile_sample_min_tsize: + return self.temporal_tiled_encode(x, return_dict=return_dict) + + if self.use_spatial_tiling and (x.shape[-1] > self.tile_sample_min_size or x.shape[-2] > self.tile_sample_min_size): + return self.spatial_tiled_encode(x, return_dict=return_dict) + + if self.use_slicing and x.shape[0] > 1: + encoded_slices = [self.encoder(x_slice) for x_slice in x.split(1)] + h = torch.cat(encoded_slices) + else: + h = self.encoder(x) + + moments = self.quant_conv(h) + posterior = DiagonalGaussianDistribution(moments) + + if not return_dict: + return (posterior,) + + return AutoencoderKLOutput(latent_dist=posterior) + + def _decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]: + assert len(z.shape) == 5, "The input tensor should have 5 dimensions." + + if self.use_temporal_tiling and z.shape[2] > self.tile_latent_min_tsize: + return self.temporal_tiled_decode(z, return_dict=return_dict) + + if self.use_spatial_tiling and (z.shape[-1] > self.tile_latent_min_size or z.shape[-2] > self.tile_latent_min_size): + return self.spatial_tiled_decode(z, return_dict=return_dict) + + z = self.post_quant_conv(z) + dec = self.decoder(z) + + if not return_dict: + return (dec,) + + return DecoderOutput(sample=dec) + + @apply_forward_hook + def decode( + self, z: torch.FloatTensor, return_dict: bool = True, generator=None + ) -> Union[DecoderOutput, torch.FloatTensor]: + """ + Decode a batch of images/videos. + + Args: + z (`torch.FloatTensor`): Input batch of latent vectors. + return_dict (`bool`, *optional*, defaults to `True`): + Whether to return a [`~models.vae.DecoderOutput`] instead of a plain tuple. + + Returns: + [`~models.vae.DecoderOutput`] or `tuple`: + If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is + returned. + + """ + if self.use_slicing and z.shape[0] > 1: + decoded_slices = [self._decode(z_slice).sample for z_slice in z.split(1)] + decoded = torch.cat(decoded_slices) + else: + decoded = self._decode(z).sample + + if not return_dict: + return (decoded,) + + return DecoderOutput(sample=decoded) + + def blend_v(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor: + blend_extent = min(a.shape[-2], b.shape[-2], blend_extent) + for y in range(blend_extent): + b[:, :, :, y, :] = a[:, :, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[:, :, :, y, :] * (y / blend_extent) + return b + + def blend_h(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor: + blend_extent = min(a.shape[-1], b.shape[-1], blend_extent) + for x in range(blend_extent): + b[:, :, :, :, x] = a[:, :, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, :, x] * (x / blend_extent) + return b + + def blend_t(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor: + blend_extent = min(a.shape[-3], b.shape[-3], blend_extent) + for x in range(blend_extent): + b[:, :, x, :, :] = a[:, :, -blend_extent + x, :, :] * (1 - x / blend_extent) + b[:, :, x, :, :] * (x / blend_extent) + return b + + def spatial_tiled_encode(self, x: torch.FloatTensor, return_dict: bool = True, return_moments: bool = False) -> AutoencoderKLOutput: + r"""Encode a batch of images/videos using a tiled encoder. + + When this option is enabled, the VAE will split the input tensor into tiles to compute encoding in several + steps. This is useful to keep memory use constant regardless of image/videos size. The end result of tiled encoding is + different from non-tiled encoding because each tile uses a different encoder. To avoid tiling artifacts, the + tiles overlap and are blended together to form a smooth output. You may still see tile-sized changes in the + output, but they should be much less noticeable. + + Args: + x (`torch.FloatTensor`): Input batch of images/videos. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple. + + Returns: + [`~models.autoencoder_kl.AutoencoderKLOutput`] or `tuple`: + If return_dict is True, a [`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain + `tuple` is returned. + """ + overlap_size = int(self.tile_sample_min_size * (1 - self.tile_overlap_factor)) + blend_extent = int(self.tile_latent_min_size * self.tile_overlap_factor) + row_limit = self.tile_latent_min_size - blend_extent + + # Split video into tiles and encode them separately. + rows = [] + for i in range(0, x.shape[-2], overlap_size): + row = [] + for j in range(0, x.shape[-1], overlap_size): + tile = x[:, :, :, i: i + self.tile_sample_min_size, j: j + self.tile_sample_min_size] + tile = self.encoder(tile) + tile = self.quant_conv(tile) + row.append(tile) + rows.append(row) + result_rows = [] + for i, row in enumerate(rows): + result_row = [] + for j, tile in enumerate(row): + # blend the above tile and the left tile + # to the current tile and add the current tile to the result row + if i > 0: + tile = self.blend_v(rows[i - 1][j], tile, blend_extent) + if j > 0: + tile = self.blend_h(row[j - 1], tile, blend_extent) + result_row.append(tile[:, :, :, :row_limit, :row_limit]) + result_rows.append(torch.cat(result_row, dim=-1)) + + moments = torch.cat(result_rows, dim=-2) + if return_moments: + return moments + + posterior = DiagonalGaussianDistribution(moments) + if not return_dict: + return (posterior,) + + return AutoencoderKLOutput(latent_dist=posterior) + + def spatial_tiled_decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]: + r""" + Decode a batch of images/videos using a tiled decoder. + + Args: + z (`torch.FloatTensor`): Input batch of latent vectors. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~models.vae.DecoderOutput`] instead of a plain tuple. + + Returns: + [`~models.vae.DecoderOutput`] or `tuple`: + If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is + returned. + """ + overlap_size = int(self.tile_latent_min_size * (1 - self.tile_overlap_factor)) + blend_extent = int(self.tile_sample_min_size * self.tile_overlap_factor) + row_limit = self.tile_sample_min_size - blend_extent + + # Split z into overlapping tiles and decode them separately. + # The tiles have an overlap to avoid seams between tiles. + rows = [] + for i in range(0, z.shape[-2], overlap_size): + row = [] + for j in range(0, z.shape[-1], overlap_size): + tile = z[:, :, :, i: i + self.tile_latent_min_size, j: j + self.tile_latent_min_size] + tile = self.post_quant_conv(tile) + decoded = self.decoder(tile) + row.append(decoded) + rows.append(row) + result_rows = [] + for i, row in enumerate(rows): + result_row = [] + for j, tile in enumerate(row): + # blend the above tile and the left tile + # to the current tile and add the current tile to the result row + if i > 0: + tile = self.blend_v(rows[i - 1][j], tile, blend_extent) + if j > 0: + tile = self.blend_h(row[j - 1], tile, blend_extent) + result_row.append(tile[:, :, :, :row_limit, :row_limit]) + result_rows.append(torch.cat(result_row, dim=-1)) + + dec = torch.cat(result_rows, dim=-2) + if not return_dict: + return (dec,) + + return DecoderOutput(sample=dec) + + def temporal_tiled_encode(self, x: torch.FloatTensor, return_dict: bool = True) -> AutoencoderKLOutput: + + B, C, T, H, W = x.shape + overlap_size = int(self.tile_sample_min_tsize * (1 - self.tile_overlap_factor)) + blend_extent = int(self.tile_latent_min_tsize * self.tile_overlap_factor) + t_limit = self.tile_latent_min_tsize - blend_extent + + # Split the video into tiles and encode them separately. + row = [] + for i in range(0, T, overlap_size): + tile = x[:, :, i: i + self.tile_sample_min_tsize + 1, :, :] + if self.use_spatial_tiling and (tile.shape[-1] > self.tile_sample_min_size or tile.shape[-2] > self.tile_sample_min_size): + tile = self.spatial_tiled_encode(tile, return_moments=True) + else: + tile = self.encoder(tile) + tile = self.quant_conv(tile) + if i > 0: + tile = tile[:, :, 1:, :, :] + row.append(tile) + result_row = [] + for i, tile in enumerate(row): + if i > 0: + tile = self.blend_t(row[i - 1], tile, blend_extent) + result_row.append(tile[:, :, :t_limit, :, :]) + else: + result_row.append(tile[:, :, :t_limit + 1, :, :]) + + moments = torch.cat(result_row, dim=2) + posterior = DiagonalGaussianDistribution(moments) + + if not return_dict: + return (posterior,) + + return AutoencoderKLOutput(latent_dist=posterior) + + def temporal_tiled_decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]: + # Split z into overlapping tiles and decode them separately. + + B, C, T, H, W = z.shape + overlap_size = int(self.tile_latent_min_tsize * (1 - self.tile_overlap_factor)) + blend_extent = int(self.tile_sample_min_tsize * self.tile_overlap_factor) + t_limit = self.tile_sample_min_tsize - blend_extent + + row = [] + for i in range(0, T, overlap_size): + tile = z[:, :, i: i + self.tile_latent_min_tsize + 1, :, :] + if self.use_spatial_tiling and (tile.shape[-1] > self.tile_latent_min_size or tile.shape[-2] > self.tile_latent_min_size): + decoded = self.spatial_tiled_decode(tile, return_dict=True).sample + else: + tile = self.post_quant_conv(tile) + decoded = self.decoder(tile) + if i > 0: + decoded = decoded[:, :, 1:, :, :] + row.append(decoded) + result_row = [] + for i, tile in enumerate(row): + if i > 0: + tile = self.blend_t(row[i - 1], tile, blend_extent) + result_row.append(tile[:, :, :t_limit, :, :]) + else: + result_row.append(tile[:, :, :t_limit + 1, :, :]) + + dec = torch.cat(result_row, dim=2) + if not return_dict: + return (dec,) + + return DecoderOutput(sample=dec) + + def forward( + self, + sample: torch.FloatTensor, + sample_posterior: bool = False, + return_dict: bool = True, + return_posterior: bool = False, + generator: Optional[torch.Generator] = None, + ) -> Union[DecoderOutput2, torch.FloatTensor]: + r""" + Args: + sample (`torch.FloatTensor`): Input sample. + sample_posterior (`bool`, *optional*, defaults to `False`): + Whether to sample from the posterior. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`DecoderOutput`] instead of a plain tuple. + """ + x = sample + posterior = self.encode(x).latent_dist + if sample_posterior: + z = posterior.sample(generator=generator) + else: + z = posterior.mode() + dec = self.decode(z).sample + + if not return_dict: + if return_posterior: + return (dec, posterior) + else: + return (dec,) + if return_posterior: + return DecoderOutput2(sample=dec, posterior=posterior) + else: + return DecoderOutput2(sample=dec) + + # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections + def fuse_qkv_projections(self): + """ + Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, + key, value) are fused. For cross-attention modules, key and value projection matrices are fused. + + + + This API is 🧪 experimental. + + + """ + self.original_attn_processors = None + + for _, attn_processor in self.attn_processors.items(): + if "Added" in str(attn_processor.__class__.__name__): + raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.") + + self.original_attn_processors = self.attn_processors + + for module in self.modules(): + if isinstance(module, Attention): + module.fuse_projections(fuse=True) + + # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections + def unfuse_qkv_projections(self): + """Disables the fused QKV projection if enabled. + + + + This API is 🧪 experimental. + + + + """ + if self.original_attn_processors is not None: + self.set_attn_processor(self.original_attn_processors) \ No newline at end of file -- Gitee From 72e2d104114fe158cc81687154387bcba69ef826 Mon Sep 17 00:00:00 2001 From: taoyuan-guo Date: Fri, 21 Feb 2025 20:35:29 +0800 Subject: [PATCH 2/5] hunyuanvideo --- .../hyvideo/vae/autoencoder_kl_causal_3d.py | 24 ++++++++----------- 1 file changed, 10 insertions(+), 14 deletions(-) diff --git a/MindIE/MindIE-Torch/built-in/foundation/hunyuan_video/hyvideo/vae/autoencoder_kl_causal_3d.py b/MindIE/MindIE-Torch/built-in/foundation/hunyuan_video/hyvideo/vae/autoencoder_kl_causal_3d.py index e635884e2f..6f10889cdd 100644 --- a/MindIE/MindIE-Torch/built-in/foundation/hunyuan_video/hyvideo/vae/autoencoder_kl_causal_3d.py +++ b/MindIE/MindIE-Torch/built-in/foundation/hunyuan_video/hyvideo/vae/autoencoder_kl_causal_3d.py @@ -25,11 +25,11 @@ import torch.nn as nn from diffusers.configuration_utils import ConfigMixin, register_to_config try: - # This diffusers is modified and packed in the mirror. + # This diffusers package is modified and packed in the mirror. from diffusers.loaders import FromOriginalVAEMixin except ImportError: - # Use this to be compatible with the original diffusers. - from diffusers.loaders.single_file_model import FromOriginalModelMixin as FromOriginalVAEMixin + # Use this to maintain compatibility with the original diffusers library. + # from diffusers.loaders.single_file_model import FromOriginalModelMixin as FromOriginalVAEMixin from diffusers.utils.accelerate_utils import apply_forward_hook from diffusers.models.attention_processor import ( ADDED_KV_ATTENTION_PROCESSORS, @@ -179,7 +179,7 @@ class AutoencoderKLCausal3D(ModelMixin, ConfigMixin, FromOriginalVAEMixin): self.use_slicing = False @property - # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.attn_processors + # Modified from diffusers.models.unet_2d_condition.UNet2DConditionModel.attn_processors def attn_processors(self) -> Dict[str, AttentionProcessor]: r""" Returns: @@ -203,7 +203,7 @@ class AutoencoderKLCausal3D(ModelMixin, ConfigMixin, FromOriginalVAEMixin): return processors - # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor + # Modified from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor def set_attn_processor( self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]], _remove_lora=False ): @@ -223,8 +223,8 @@ class AutoencoderKLCausal3D(ModelMixin, ConfigMixin, FromOriginalVAEMixin): if isinstance(processor, dict) and len(processor) != count: raise ValueError( - f"A dict of processors was passed, but the number of processors {len(processor)} does not match the" - f" number of attention layers: {count}. Please make sure to pass {count} processor classes." + f"A dictionary of processors was passed, but the number of processors {len(processor)} does not match the" + f" number of attention layers: {count}. Please ensure you pass {count} processor classes." ) def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor): @@ -240,7 +240,7 @@ class AutoencoderKLCausal3D(ModelMixin, ConfigMixin, FromOriginalVAEMixin): for name, module in self.named_children(): fn_recursive_attn_processor(name, module, processor) - # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor + # Modified from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor def set_default_attn_processor(self): """ Disables custom attention processors and sets the default attention implementation. @@ -272,8 +272,6 @@ class AutoencoderKLCausal3D(ModelMixin, ConfigMixin, FromOriginalVAEMixin): The latent representations of the encoded images/videos. If `return_dict` is True, a [`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain `tuple` is returned. """ - assert len(x.shape) == 5, "The input tensor should have 5 dimensions." - if self.use_temporal_tiling and x.shape[2] > self.tile_sample_min_tsize: return self.temporal_tiled_encode(x, return_dict=return_dict) @@ -295,8 +293,6 @@ class AutoencoderKLCausal3D(ModelMixin, ConfigMixin, FromOriginalVAEMixin): return AutoencoderKLOutput(latent_dist=posterior) def _decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]: - assert len(z.shape) == 5, "The input tensor should have 5 dimensions." - if self.use_temporal_tiling and z.shape[2] > self.tile_latent_min_tsize: return self.temporal_tiled_decode(z, return_dict=return_dict) @@ -564,7 +560,7 @@ class AutoencoderKLCausal3D(ModelMixin, ConfigMixin, FromOriginalVAEMixin): else: return DecoderOutput2(sample=dec) - # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections + # Modified from diffusers.models.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections def fuse_qkv_projections(self): """ Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, @@ -588,7 +584,7 @@ class AutoencoderKLCausal3D(ModelMixin, ConfigMixin, FromOriginalVAEMixin): if isinstance(module, Attention): module.fuse_projections(fuse=True) - # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections + # Modified from diffusers.models.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections def unfuse_qkv_projections(self): """Disables the fused QKV projection if enabled. -- Gitee From c81d5a782e7e7afd4e9e781a4b1cbd7dd9968732 Mon Sep 17 00:00:00 2001 From: taoyuan-guo Date: Mon, 24 Feb 2025 10:39:52 +0800 Subject: [PATCH 3/5] hunyuanvideo --- .../hunyuan_video/hyvideo/vae/autoencoder_kl_causal_3d.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/MindIE/MindIE-Torch/built-in/foundation/hunyuan_video/hyvideo/vae/autoencoder_kl_causal_3d.py b/MindIE/MindIE-Torch/built-in/foundation/hunyuan_video/hyvideo/vae/autoencoder_kl_causal_3d.py index 6f10889cdd..f8348d8b32 100644 --- a/MindIE/MindIE-Torch/built-in/foundation/hunyuan_video/hyvideo/vae/autoencoder_kl_causal_3d.py +++ b/MindIE/MindIE-Torch/built-in/foundation/hunyuan_video/hyvideo/vae/autoencoder_kl_causal_3d.py @@ -272,6 +272,9 @@ class AutoencoderKLCausal3D(ModelMixin, ConfigMixin, FromOriginalVAEMixin): The latent representations of the encoded images/videos. If `return_dict` is True, a [`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain `tuple` is returned. """ + if len(x.shape) == 5: + raise ValueError("The input tensor should have 5 dimensions.") + if self.use_temporal_tiling and x.shape[2] > self.tile_sample_min_tsize: return self.temporal_tiled_encode(x, return_dict=return_dict) @@ -293,6 +296,9 @@ class AutoencoderKLCausal3D(ModelMixin, ConfigMixin, FromOriginalVAEMixin): return AutoencoderKLOutput(latent_dist=posterior) def _decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]: + if len(z.shape) == 5: + raise ValueError("The input tensor should have 5 dimensions.") + if self.use_temporal_tiling and z.shape[2] > self.tile_latent_min_tsize: return self.temporal_tiled_decode(z, return_dict=return_dict) -- Gitee From 4d42621ce66559f63e637d00263ca2162f6f183b Mon Sep 17 00:00:00 2001 From: taoyuan-guo Date: Mon, 24 Feb 2025 10:41:21 +0800 Subject: [PATCH 4/5] hunyuanvideo --- .../hunyuan_video/hyvideo/vae/autoencoder_kl_causal_3d.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/MindIE/MindIE-Torch/built-in/foundation/hunyuan_video/hyvideo/vae/autoencoder_kl_causal_3d.py b/MindIE/MindIE-Torch/built-in/foundation/hunyuan_video/hyvideo/vae/autoencoder_kl_causal_3d.py index f8348d8b32..c6f162dd4f 100644 --- a/MindIE/MindIE-Torch/built-in/foundation/hunyuan_video/hyvideo/vae/autoencoder_kl_causal_3d.py +++ b/MindIE/MindIE-Torch/built-in/foundation/hunyuan_video/hyvideo/vae/autoencoder_kl_causal_3d.py @@ -272,7 +272,7 @@ class AutoencoderKLCausal3D(ModelMixin, ConfigMixin, FromOriginalVAEMixin): The latent representations of the encoded images/videos. If `return_dict` is True, a [`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain `tuple` is returned. """ - if len(x.shape) == 5: + if len(x.shape) != 5: raise ValueError("The input tensor should have 5 dimensions.") if self.use_temporal_tiling and x.shape[2] > self.tile_sample_min_tsize: @@ -296,7 +296,7 @@ class AutoencoderKLCausal3D(ModelMixin, ConfigMixin, FromOriginalVAEMixin): return AutoencoderKLOutput(latent_dist=posterior) def _decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]: - if len(z.shape) == 5: + if len(z.shape) != 5: raise ValueError("The input tensor should have 5 dimensions.") if self.use_temporal_tiling and z.shape[2] > self.tile_latent_min_tsize: -- Gitee From 1f416d28a210d89f371b6853a110fa657779f03e Mon Sep 17 00:00:00 2001 From: taoyuan-guo Date: Mon, 24 Feb 2025 15:44:04 +0800 Subject: [PATCH 5/5] =?UTF-8?q?hunyuanvideo=E5=88=9D=E7=89=88=E4=BB=A3?= =?UTF-8?q?=E7=A0=81?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../hyvideo/vae/mod_attention.py | 107 ++++++++++++++++++ 1 file changed, 107 insertions(+) create mode 100644 MindIE/MindIE-Torch/built-in/foundation/hunyuan_video/hyvideo/vae/mod_attention.py diff --git a/MindIE/MindIE-Torch/built-in/foundation/hunyuan_video/hyvideo/vae/mod_attention.py b/MindIE/MindIE-Torch/built-in/foundation/hunyuan_video/hyvideo/vae/mod_attention.py new file mode 100644 index 0000000000..4a7b9c18f7 --- /dev/null +++ b/MindIE/MindIE-Torch/built-in/foundation/hunyuan_video/hyvideo/vae/mod_attention.py @@ -0,0 +1,107 @@ +from typing import Optional +import torch +import torch.nn.functional as F +import math +import torch_npu +from diffusers.models.attention_processor import Attention +MAX_TOKEN = 2147483647 + + +class AttnProcessor2_0: + r""" + Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). + """ + + def __init__(self): + if not hasattr(F, "scaled_dot_product_attention"): + raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.") + + def __call__( + self, + attn: Attention, + hidden_states: torch.Tensor, + encoder_hidden_states: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + temb: Optional[torch.Tensor] = None, + *args, + **kwargs, + ) -> torch.Tensor: + + residual = hidden_states + if attn.spatial_norm is not None: + hidden_states = attn.spatial_norm(hidden_states, temb) + + input_ndim = hidden_states.ndim + + if input_ndim == 4: + batch_size, channel, height, width = hidden_states.shape + hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) + + batch_size, sequence_length, _ = ( + hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape + ) + + if attention_mask is not None: + attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) + # scaled_dot_product_attention expects attention_mask shape to be + # (batch, heads, source_length, target_length) + attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1]) + + if attn.group_norm is not None: + hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) + + query = attn.to_q(hidden_states) + + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + elif attn.norm_cross: + encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) + + key = attn.to_k(encoder_hidden_states) + value = attn.to_v(encoder_hidden_states) + + inner_dim = key.shape[-1] + head_dim = inner_dim // attn.heads + + query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + if attn.norm_q is not None: + query = attn.norm_q(query) + if attn.norm_k is not None: + key = attn.norm_k(key) + + # the output of sdp = (batch, num_heads, seq_len, head_dim) + scale = 1.0 / math.sqrt(head_dim) + hidden_states = torch_npu.npu_fusion_attention( + query, key, value, + head_num=attn.heads, + input_layout="BNSD", + scale=scale, + pse=None, + atten_mask=attention_mask, + pre_tockens=MAX_TOKEN, + next_tockens=MAX_TOKEN, + keep_prob=1.0, + sync=False + )[0] + + hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) + hidden_states = hidden_states.to(query.dtype) + + # linear proj + hidden_states = attn.to_out[0](hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + + if input_ndim == 4: + hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) + + if attn.residual_connection: + hidden_states = hidden_states + residual + + hidden_states = hidden_states / attn.rescale_output_factor + + return hidden_states \ No newline at end of file -- Gitee