diff --git a/MindIE/MindIE-Torch/built-in/foundation/cogview3/README.md b/MindIE/MindIE-Torch/built-in/foundation/cogview3/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..e161ab5c1018a5ab7ac79eaacac77eefaba2e798
--- /dev/null
+++ b/MindIE/MindIE-Torch/built-in/foundation/cogview3/README.md
@@ -0,0 +1,166 @@
+## 1. Prepare the Runtime Environment
+
+ **Table 1** Version compatibility
+
+ | Component | Version | Setup guide |
+ | ----- | ----- |-----|
+ | Python | 3.10.12 | - |
+ | torch | 2.4.0 | - |
+
+### 1.1 Obtain the CANN & MindIE packages and prepare the environment
+- [800I A2](https://www.hiascend.com/developer/download/community/result?module=pt+ie+cann&product=4&model=32)
+- [Duo card](https://www.hiascend.com/developer/download/community/result?module=pt+ie+cann&product=2&model=17)
+- [Environment setup guide](https://www.hiascend.com/document/detail/zh/CANNCommunityEdition/80RC2alpha002/softwareinst/instg/instg_0001.html)
+
+### 1.2 Install CANN
+```shell
+# Make the packages executable. {version} is the software version, {arch} is the CPU architecture, and {soc} is the Ascend AI processor version.
+chmod +x ./Ascend-cann-toolkit_{version}_linux-{arch}.run
+chmod +x ./Ascend-cann-kernels-{soc}_{version}_linux.run
+# Verify the consistency and integrity of the installation packages
+./Ascend-cann-toolkit_{version}_linux-{arch}.run --check
+./Ascend-cann-kernels-{soc}_{version}_linux.run --check
+# Install
+./Ascend-cann-toolkit_{version}_linux-{arch}.run --install
+./Ascend-cann-kernels-{soc}_{version}_linux.run --install
+
+# Set environment variables
+source /usr/local/Ascend/ascend-toolkit/set_env.sh
+```
+
+### 1.3 Install MindIE
+```shell
+# Make the package executable. {version} is the software version and {arch} is the CPU architecture.
+chmod +x ./Ascend-mindie_${version}_linux-${arch}.run
+./Ascend-mindie_${version}_linux-${arch}.run --check
+
+# Option 1: install to the default path
+./Ascend-mindie_${version}_linux-${arch}.run --install
+# Set environment variables
+cd /usr/local/Ascend/mindie && source set_env.sh
+
+# Option 2: install to a custom path
+./Ascend-mindie_${version}_linux-${arch}.run --install-path=${AieInstallPath}
+# Set environment variables
+cd ${AieInstallPath}/mindie && source set_env.sh
+```
+
+### 1.4 Install torch_npu
+Install the PyTorch framework, version 2.4.0.
+[Package download](https://download.pytorch.org/whl/cpu/torch/)
+
+Install with pip:
+```shell
+# {version} is the software version and {arch} is the CPU architecture.
+pip install torch-${version}-cp310-cp310-linux_${arch}.whl
+```
+Download pytorch_v{pytorchversion}_py{pythonversion}.tar.gz:
+```shell
+tar -xzvf pytorch_v{pytorchversion}_py{pythonversion}.tar.gz
+# After extraction, install the included whl package
+pip install torch_npu-{pytorchversion}.xxxx.{arch}.whl
+```
+## 2. Download This Repository
+
+### 2.1 Clone to a local machine
+```shell
+ git clone https://gitee.com/ascend/ModelZoo-PyTorch.git
+```
+
+## 3. Using CogView3
+
+### 3.1 Weights and configuration files
+1. CogView3 weights:
+```shell
+https://huggingface.co/THUDM/CogView3-Plus-3B/tree/main
+```
+- Modify the model_index.json of these weights as follows:
+```json
+{
+    "_class_name": "CogView3PlusPipeline",
+    "_diffusers_version": "0.31.0",
+    "scheduler": [
+        "cogview3plus",
+        "CogVideoXDDIMScheduler"
+    ],
+    "text_encoder": [
+        "transformers",
+        "T5EncoderModel"
+    ],
+    "tokenizer": [
+        "transformers",
+        "T5Tokenizer"
+    ],
+    "transformer": [
+        "cogview3plus",
+        "CogView3PlusTransformer2DModel"
+    ],
+    "vae": [
+        "diffusers",
+        "AutoencoderKL"
+    ]
+}
+```
+2. scheduler weights:
+```shell
+https://huggingface.co/THUDM/CogView3-Plus-3B/tree/main/scheduler
+```
+3. text_encoder weights:
+```shell
+https://huggingface.co/THUDM/CogView3-Plus-3B/tree/main/text_encoder
+```
+4. tokenizer weights:
+```shell
+https://huggingface.co/THUDM/CogView3-Plus-3B/tree/main/tokenizer
+```
+5. transformer weights:
+```shell
+https://huggingface.co/THUDM/CogView3-Plus-3B/tree/main/transformer
+```
+6. vae weights:
+```shell
+https://huggingface.co/THUDM/CogView3-Plus-3B/tree/main/vae
+```
+7. The expected directory layout of the configuration and weight files is shown below.
+```commandline
+|----CogView3B
+|    |---- configuration.json
+|    |---- model_index.json
+|    |---- scheduler
+|    |    |---- scheduler_config.json
+|    |---- text_encoder
+|    |    |---- config.json
+|    |    |---- model weights
+|    |---- tokenizer
+|    |    |---- config.json
+|    |    |---- model weights
+|    |---- transformer
+|    |    |---- config.json
+|    |    |---- model weights
+|    |---- vae
+|    |    |---- config.json
+|    |    |---- model weights
+```
+
+### 3.2 Single-card, single-prompt functional test
+Set the weight path:
+```shell
+model_path='/data/CogView3B'
+```
+Run the command:
+```shell
+python inference_cogview3plus.py \
+    --model_path ${model_path} \
+    --device_id 0 \
+    --width 1024 \
+    --height 1024 \
+    --num_inference_steps 50 \
+    --dtype bf16
+```
+Parameter description:
+- model_path: path to the weights; it contains the configuration files and weights of the five components scheduler, text_encoder, tokenizer, transformer, and vae.
+- device_id: ID of the inference device.
+- width: width of the generated image.
+- height: height of the generated image.
+- num_inference_steps: number of inference steps.
+- dtype: data type. Currently only bf16 is supported.
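+
+For reference, the same inference flow can also be driven from the Python API. The snippet below is a minimal sketch, assuming the `cogview3plus` package in this repository is importable, `torch_npu` is installed, and the weight directory follows the layout in section 3.1; the call signature follows the diffusers `CogView3PlusPipeline` interface and may differ in detail from what `inference_cogview3plus.py` does internally.
+```python
+import torch
+import torch_npu  # registers the NPU device type with PyTorch
+
+from cogview3plus import CogView3PlusPipeline
+
+# Load the five components (scheduler, text_encoder, tokenizer, transformer, vae)
+# from the local weight directory and run in bf16, matching --dtype bf16 above.
+pipe = CogView3PlusPipeline.from_pretrained("/data/CogView3B", torch_dtype=torch.bfloat16)
+pipe.to("npu:0")  # corresponds to --device_id 0
+
+image = pipe(
+    prompt="a photo of an astronaut riding a horse on the moon",
+    width=1024,
+    height=1024,
+    num_inference_steps=50,
+).images[0]
+image.save("cogview3plus_output.png")
+```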
diff --git a/MindIE/MindIE-Torch/built-in/foundation/cogview3/cogview3plus/__init__.py b/MindIE/MindIE-Torch/built-in/foundation/cogview3/cogview3plus/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..1139593a36dcd44b33934bf240adf6ab4a477613
--- /dev/null
+++ b/MindIE/MindIE-Torch/built-in/foundation/cogview3/cogview3plus/__init__.py
@@ -0,0 +1,3 @@
+from .pipeline import CogView3PlusPipeline, DiffusionPipeline
+from .schedulers import CogVideoXDDIMScheduler, SchedulerMixin
+from .models import CogView3PlusTransformer2DModel, ModelMixin
\ No newline at end of file
diff --git a/MindIE/MindIE-Torch/built-in/foundation/cogview3/cogview3plus/layers/__init__.py b/MindIE/MindIE-Torch/built-in/foundation/cogview3/cogview3plus/layers/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..09760b9fd0e5838eba0865adbac38a31d59768a1
--- /dev/null
+++ b/MindIE/MindIE-Torch/built-in/foundation/cogview3/cogview3plus/layers/__init__.py
@@ -0,0 +1,3 @@
+from .normalization import CogView3PlusAdaLayerNormZeroTextImage, AdaLayerNormContinuous
+from .embeddings import CogView3CombinedTimestepSizeEmbeddings, CogView3PlusPatchEmbed
+from .linear import QKVLinear
\ No newline at end of file
diff --git a/MindIE/MindIE-Torch/built-in/foundation/cogview3/cogview3plus/layers/embeddings.py b/MindIE/MindIE-Torch/built-in/foundation/cogview3/cogview3plus/layers/embeddings.py
new file mode 100644
index 0000000000000000000000000000000000000000..129384dffc0fb92090012e46a8d89e7f4d0a1d6c
--- /dev/null
+++ b/MindIE/MindIE-Torch/built-in/foundation/cogview3/cogview3plus/layers/embeddings.py
@@ -0,0 +1,304 @@
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+ +import math +from typing import Optional + +import torch +from torch import nn +from diffusers.models.activations import get_activation + + +def get_timestep_embedding( + timesteps: torch.Tensor, + embedding_dim: int, + flip_sin_to_cos: bool = False, + downscale_freq_shift: float = 1, + max_period: int = 10000, +): + half_dim = embedding_dim // 2 + exponent = -math.log(max_period) * torch.arange( + start=0, end=half_dim, dtype=torch.float32, device=timesteps.device + ) + exponent = exponent / (half_dim - downscale_freq_shift) + + emb = torch.exp(exponent) + emb = timesteps[:, None].float() * emb[None, :] + + # concat sine and cosine embeddings + emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=-1) + + # flip sine and cosine embeddings + if flip_sin_to_cos: + emb = torch.cat([emb[:, half_dim:], emb[:, :half_dim]], dim=-1) + + # zero pad + if embedding_dim % 2 == 1: + emb = torch.nn.functional.pad(emb, (0, 1, 0, 0)) + return emb + + +def get_2d_sincos_pos_embed( + embed_dim, + grid_size, + interpolation_scale=1.0, + base_size=16, +): + if isinstance(grid_size, int): + grid_size = (grid_size, grid_size) + + grid_h = ( + torch.arange(grid_size[0], dtype=torch.float32) + / (grid_size[0] / base_size) + / interpolation_scale + ) + grid_w = ( + torch.arange(grid_size[1], dtype=torch.float32) + / (grid_size[1] / base_size) + / interpolation_scale + ) + grid = torch.meshgrid(grid_w, grid_h, indexing="xy") # here w goes first + grid = torch.stack(grid, dim=0) + + grid = grid.reshape([2, 1, grid_size[1], grid_size[0]]) + pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid) + return pos_embed + + +def get_2d_sincos_pos_embed_from_grid(embed_dim, grid): + r""" + This function generates 2D sinusoidal positional embeddings from a grid. + + Args: + embed_dim (`int`): The embedding dimension. + grid (`torch.Tensor`): Grid of positions with shape `(H * W,)`. + + Returns: + `torch.Tensor`: The 2D sinusoidal positional embeddings with shape `(H * W, embed_dim)` + """ + if embed_dim % 2 != 0: + raise ValueError("embed_dim must be divisible by 2") + + # use half of dimensions to encode grid_h + emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2) + emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2) + + emb = torch.concat([emb_h, emb_w], dim=1) # (H*W, D) + return emb + + +def get_1d_sincos_pos_embed_from_grid(embed_dim, pos): + """ + This function generates 1D positional embeddings from a grid. + + Args: + embed_dim (`int`): The embedding dimension `D` + pos (`torch.Tensor`): 1D tensor of positions with shape `(M,)` + + Returns: + `torch.Tensor`: Sinusoidal positional embeddings of shape `(M, D)`. 
+ """ + if embed_dim % 2 != 0: + raise ValueError("embed_dim must be divisible by 2") + + omega = torch.arange(embed_dim // 2, device=pos.device, dtype=torch.float64) + omega /= embed_dim / 2.0 + omega = 1.0 / 10000**omega # (D/2,) + + pos = pos.reshape(-1) # (M,) + out = torch.outer(pos, omega) # (M, D/2), outer product + + emb_sin = torch.sin(out) # (M, D/2) + emb_cos = torch.cos(out) # (M, D/2) + + emb = torch.concat([emb_sin, emb_cos], dim=1) # (M, D) + return emb + + +class Timesteps(nn.Module): + def __init__(self, num_channels: int, flip_sin_to_cos: bool, downscale_freq_shift: float): + super().__init__() + self.num_channels = num_channels + self.flip_sin_to_cos = flip_sin_to_cos + self.downscale_freq_shift = downscale_freq_shift + + def forward(self, timesteps): + t_emb = get_timestep_embedding( + timesteps, + self.num_channels, + flip_sin_to_cos=self.flip_sin_to_cos, + downscale_freq_shift=self.downscale_freq_shift, + ) + return t_emb + + +class TimestepEmbedding(nn.Module): + def __init__( + self, + in_channels: int, + time_embed_dim: int, + act_fn: str = "silu", + out_dim: int = None, + post_act_fn: Optional[str] = None, + cond_proj_dim=None, + sample_proj_bias=True, + ): + super().__init__() + + self.linear_1 = nn.Linear(in_channels, time_embed_dim, sample_proj_bias) + + if cond_proj_dim is not None: + self.cond_proj = nn.Linear(cond_proj_dim, in_channels, bias=False) + else: + self.cond_proj = None + + self.act = get_activation(act_fn) + + if out_dim is not None: + time_embed_dim_out = out_dim + else: + time_embed_dim_out = time_embed_dim + self.linear_2 = nn.Linear(time_embed_dim, time_embed_dim_out, sample_proj_bias) + + if post_act_fn is None: + self.post_act = None + else: + self.post_act = get_activation(post_act_fn) + + def forward(self, sample, condition=None): + if condition is not None: + sample = sample + self.cond_proj(condition) + sample = self.linear_1(sample) + + if self.act is not None: + sample = self.act(sample) + + sample = self.linear_2(sample) + + if self.post_act is not None: + sample = self.post_act(sample) + return sample + + +class PixArtAlphaTextProjection(nn.Module): + """ + Projects caption embeddings. Also handles dropout for classifier-free guidance. 
+ """ + + def __init__(self, in_features, hidden_size, out_features=None, act_fn="gelu_tanh"): + super().__init__() + if out_features is None: + out_features = hidden_size + self.linear_1 = nn.Linear(in_features=in_features, out_features=hidden_size, bias=True) + if act_fn == "gelu_tanh": + self.act_1 = nn.GELU(approximate="tanh") + elif act_fn == "silu": + self.act_1 = nn.SiLU() + else: + raise ValueError(f"Unknown activation function: {act_fn}") + self.linear_2 = nn.Linear(in_features=hidden_size, out_features=out_features, bias=True) + + def forward(self, caption): + hidden_states = self.linear_1(caption) + hidden_states = self.act_1(hidden_states) + hidden_states = self.linear_2(hidden_states) + return hidden_states + + +class CogView3CombinedTimestepSizeEmbeddings(nn.Module): + def __init__(self, embedding_dim: int, condition_dim: int, pooled_projection_dim: int, timesteps_dim: int = 256): + super().__init__() + + self.time_proj = Timesteps(num_channels=timesteps_dim, flip_sin_to_cos=True, downscale_freq_shift=0) + self.condition_proj = Timesteps(num_channels=condition_dim, flip_sin_to_cos=True, downscale_freq_shift=0) + self.timestep_embedder = TimestepEmbedding(in_channels=timesteps_dim, time_embed_dim=embedding_dim) + self.condition_embedder = PixArtAlphaTextProjection(pooled_projection_dim, embedding_dim, act_fn="silu") + + def forward( + self, + timestep: torch.Tensor, + original_size: torch.Tensor, + target_size: torch.Tensor, + crop_coords: torch.Tensor, + hidden_dtype: torch.dtype, + ) -> torch.Tensor: + timesteps_proj = self.time_proj(timestep) + + original_size_proj = self.condition_proj(original_size.flatten()).view(original_size.size(0), -1) + crop_coords_proj = self.condition_proj(crop_coords.flatten()).view(crop_coords.size(0), -1) + target_size_proj = self.condition_proj(target_size.flatten()).view(target_size.size(0), -1) + + condition_proj = torch.cat([original_size_proj, crop_coords_proj, target_size_proj], dim=1) + + timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=hidden_dtype)) # (B, embedding_dim) + condition_emb = self.condition_embedder(condition_proj.to(dtype=hidden_dtype)) # (B, embedding_dim) + + conditioning = timesteps_emb + condition_emb + return conditioning + + +class CogView3PlusPatchEmbed(nn.Module): + def __init__( + self, + in_channels: int = 16, + hidden_size: int = 2560, + patch_size: int = 2, + text_hidden_size: int = 4096, + pos_embed_max_size: int = 128, + ): + super().__init__() + self.in_channels = in_channels + self.hidden_size = hidden_size + self.patch_size = patch_size + self.text_hidden_size = text_hidden_size + self.pos_embed_max_size = pos_embed_max_size + # Linear projection for image patches + self.proj = nn.Linear(in_channels * patch_size**2, hidden_size) + + # Linear projection for text embeddings + self.text_proj = nn.Linear(text_hidden_size, hidden_size) + + pos_embed = get_2d_sincos_pos_embed( + hidden_size, pos_embed_max_size, base_size=pos_embed_max_size + ) + pos_embed = pos_embed.reshape(pos_embed_max_size, pos_embed_max_size, hidden_size) + self.register_buffer("pos_embed", pos_embed.float(), persistent=False) + + def forward(self, hidden_states: torch.Tensor, encoder_hidden_states: torch.Tensor) -> torch.Tensor: + batch_size, channel, height, width = hidden_states.shape + + if height % self.patch_size != 0 or width % self.patch_size != 0: + raise ValueError("Height and width must be divisible by patch size") + + height = height // self.patch_size + width = width // self.patch_size + hidden_states = 
hidden_states.view(batch_size, channel, height, self.patch_size, width, self.patch_size) + hidden_states = hidden_states.permute(0, 2, 4, 1, 3, 5).contiguous() + hidden_states = hidden_states.view(batch_size, height * width, channel * self.patch_size * self.patch_size) + + # Project the patches + hidden_states = self.proj(hidden_states) + encoder_hidden_states = self.text_proj(encoder_hidden_states) + hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1) + + # Calculate text_length + text_length = encoder_hidden_states.shape[1] + + image_pos_embed = self.pos_embed[:height, :width].reshape(height * width, -1) + text_pos_embed = torch.zeros( + (text_length, self.hidden_size), dtype=image_pos_embed.dtype, device=image_pos_embed.device + ) + pos_embed = torch.cat([text_pos_embed, image_pos_embed], dim=0)[None, ...] + + return (hidden_states + pos_embed).to(hidden_states.dtype) \ No newline at end of file diff --git a/MindIE/MindIE-Torch/built-in/foundation/cogview3/cogview3plus/layers/linear.py b/MindIE/MindIE-Torch/built-in/foundation/cogview3/cogview3plus/layers/linear.py new file mode 100644 index 0000000000000000000000000000000000000000..57fe8d55dc4d030b1def6f8e7a97e9049a69fdc9 --- /dev/null +++ b/MindIE/MindIE-Torch/built-in/foundation/cogview3/cogview3plus/layers/linear.py @@ -0,0 +1,48 @@ +#!/usr/bin/env python +# coding=utf-8 +# Copyright 2024 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import torch +import torch.nn as nn + + +class QKVLinear(nn.Module): + def __init__(self, attention_dim, hidden_size, qkv_bias=True, device=None, dtype=None): + super(QKVLinear, self).__init__() + self.attention_dim = attention_dim + self.hidden_size = hidden_size + self.qkv_bias = qkv_bias + + factory_kwargs = {"device": device, "dtype": dtype} + + self.weight = nn.Parameter(torch.empty([self.attention_dim, 3 * self.hidden_size], **factory_kwargs)) + if self.qkv_bias: + self.bias = nn.Parameter(torch.empty([3 * self.hidden_size], **factory_kwargs)) + + def forward(self, hidden_states): + + if not self.qkv_bias: + qkv = torch.matmul(hidden_states, self.weight) + else: + qkv = torch.addmm( + self.bias, + hidden_states.view(hidden_states.size(0) * hidden_states.size(1), hidden_states.size(2)), + self.weight, + beta=1, + alpha=1 + ) + + return qkv \ No newline at end of file diff --git a/MindIE/MindIE-Torch/built-in/foundation/cogview3/cogview3plus/layers/normalization.py b/MindIE/MindIE-Torch/built-in/foundation/cogview3/cogview3plus/layers/normalization.py new file mode 100644 index 0000000000000000000000000000000000000000..1ec0a5b15c93a8f3635abb42b4588e42c0daef9b --- /dev/null +++ b/MindIE/MindIE-Torch/built-in/foundation/cogview3/cogview3plus/layers/normalization.py @@ -0,0 +1,177 @@ +# coding=utf-8 +# Copyright 2024 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import numbers +from typing import Optional, Tuple +from dataclasses import dataclass + +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class RMSNorm(nn.Module): + def __init__(self, dim, eps: float, elementwise_affine: bool = True, bias: bool = False): + super().__init__() + + self.eps = eps + self.elementwise_affine = elementwise_affine + + if isinstance(dim, numbers.Integral): + dim = (dim,) + + self.dim = torch.Size(dim) + + self.weight = None + self.bias = None + + if elementwise_affine: + self.weight = nn.Parameter(torch.ones(dim)) + if bias: + self.bias = nn.Parameter(torch.zeros(dim)) + + def forward(self, hidden_states): + input_dtype = hidden_states.dtype + variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.eps) + + if self.weight is not None: + # convert into half-precision if necessary + if self.weight.dtype in [torch.float16, torch.bfloat16]: + hidden_states = hidden_states.to(self.weight.dtype) + hidden_states = hidden_states * self.weight + if self.bias is not None: + hidden_states = hidden_states + self.bias + else: + hidden_states = hidden_states.to(input_dtype) + + return hidden_states + + +@dataclass +class ChunkParam: + gate_msa: torch.Tensor + shift_mlp: torch.Tensor + scale_mlp: torch.Tensor + gate_mlp: torch.Tensor + context: torch.Tensor + c_gate_msa: torch.Tensor + c_shift_mlp: torch.Tensor + c_scale_mlp: torch.Tensor + c_gate_mlp: torch.Tensor + + +class CogView3PlusAdaLayerNormZeroTextImage(nn.Module): + r""" + Norm layer adaptive layer norm zero (adaLN-Zero). + + Parameters: + embedding_dim (`int`): The size of each embedding vector. + num_embeddings (`int`): The size of the embeddings dictionary. 
+ """ + + def __init__(self, embedding_dim: int, dim: int): + super().__init__() + + self.silu = nn.SiLU() + self.linear = nn.Linear(embedding_dim, 12 * dim, bias=True) + self.norm_x = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-5) + self.norm_c = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-5) + + def forward( + self, + x: torch.Tensor, + context: torch.Tensor, + emb: Optional[torch.Tensor] = None, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + emb = self.linear(self.silu(emb)) + ( + shift_msa, + scale_msa, + gate_msa, + shift_mlp, + scale_mlp, + gate_mlp, + c_shift_msa, + c_scale_msa, + c_gate_msa, + c_shift_mlp, + c_scale_mlp, + c_gate_mlp, + ) = emb.chunk(12, dim=1) + normed_x = self.norm_x(x) + normed_context = self.norm_c(context) + x = normed_x * (1 + scale_msa[:, None]) + shift_msa[:, None] + context = normed_context * (1 + c_scale_msa[:, None]) + c_shift_msa[:, None] + return x, ChunkParam( + gate_msa, shift_mlp, scale_mlp, gate_mlp, context, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp + ) + + +class FP32LayerNorm(nn.LayerNorm): + def forward(self, inputs: torch.Tensor) -> torch.Tensor: + origin_dtype = inputs.dtype + return F.layer_norm( + inputs.float(), + self.normalized_shape, + self.weight.float() if self.weight is not None else None, + self.bias.float() if self.bias is not None else None, + self.eps, + ).to(origin_dtype) + + +class LpNorm(nn.Module): + def __init__(self, p: int = 2, dim: int = -1, eps: float = 1e-12): + super().__init__() + + self.p = p + self.dim = dim + self.eps = eps + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + return F.normalize(hidden_states, p=self.p, dim=self.dim, eps=self.eps) + + +class AdaLayerNormContinuous(nn.Module): + def __init__( + self, + embedding_dim: int, + conditioning_embedding_dim: int, + # NOTE: It is a bit weird that the norm layer can be configured to have scale and shift parameters + # because the output is immediately scaled and shifted by the projected conditioning embeddings. + # Note that AdaLayerNorm does not let the norm layer have scale and shift parameters. + # However, this is how it was implemented in the original code, and it's rather likely you should + # set `elementwise_affine` to False. 
+ elementwise_affine=True, + eps=1e-5, + bias=True, + norm_type="layer_norm", + ): + super().__init__() + self.silu = nn.SiLU() + self.linear = nn.Linear(conditioning_embedding_dim, embedding_dim * 2, bias=bias) + if norm_type == "layer_norm": + self.norm = nn.LayerNorm(embedding_dim, eps, elementwise_affine, bias) + elif norm_type == "rms_norm": + self.norm = RMSNorm(embedding_dim, eps, elementwise_affine) + else: + raise ValueError(f"unknown norm_type {norm_type}") + + def forward(self, x: torch.Tensor, conditioning_embedding: torch.Tensor) -> torch.Tensor: + # convert back to the original dtype in case `conditioning_embedding`` is upcasted to float32 (needed for hunyuanDiT) + emb = self.linear(self.silu(conditioning_embedding).to(x.dtype)) + scale, shift = torch.chunk(emb, 2, dim=1) + x = self.norm(x) * (1 + scale)[:, None, :] + shift[:, None, :] + return x \ No newline at end of file diff --git a/MindIE/MindIE-Torch/built-in/foundation/cogview3/cogview3plus/models/__init__.py b/MindIE/MindIE-Torch/built-in/foundation/cogview3/cogview3plus/models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..b3c595bfccdd80de85b0883dc5246eb86c19d436 --- /dev/null +++ b/MindIE/MindIE-Torch/built-in/foundation/cogview3/cogview3plus/models/__init__.py @@ -0,0 +1,2 @@ +from .transformer_cogview3plus import CogView3PlusTransformer2DModel +from .modeling_utils import ModelMixin \ No newline at end of file diff --git a/MindIE/MindIE-Torch/built-in/foundation/cogview3/cogview3plus/models/activations.py b/MindIE/MindIE-Torch/built-in/foundation/cogview3/cogview3plus/models/activations.py new file mode 100644 index 0000000000000000000000000000000000000000..5bb3783ae4bb753995bcf9ddf5e23f3ee05f6ef1 --- /dev/null +++ b/MindIE/MindIE-Torch/built-in/foundation/cogview3/cogview3plus/models/activations.py @@ -0,0 +1,163 @@ +# coding=utf-8 +# Copyright 2024 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +import torch.nn.functional as F +from torch import nn + +from diffusers.utils import deprecate +from diffusers.utils.import_utils import is_torch_npu_available + +if is_torch_npu_available(): + import torch_npu + +ACTIVATION_FUNCTIONS = { + "swish": nn.SiLU(), + "silu": nn.SiLU(), + "mish": nn.Mish(), + "gelu": nn.GELU(), + "relu": nn.ReLU(), +} + + +def get_activation(act_fn: str) -> nn.Module: + """Helper function to get activation function from string. + + Args: + act_fn (str): Name of activation function. + + Returns: + nn.Module: Activation function. + """ + + act_fn = act_fn.lower() + if act_fn in ACTIVATION_FUNCTIONS: + return ACTIVATION_FUNCTIONS[act_fn] + else: + raise ValueError(f"Unsupported activation function: {act_fn}") + + +class FP32SiLU(nn.Module): + r""" + SiLU activation function with input upcasted to torch.float32. 
+ """ + + def __init__(self): + super().__init__() + + def forward(self, inputs: torch.Tensor) -> torch.Tensor: + return F.silu(inputs.float(), inplace=False).to(inputs.dtype) + + +class GELU(nn.Module): + r""" + GELU activation function with tanh approximation support with `approximate="tanh"`. + + Parameters: + dim_in (`int`): The number of channels in the input. + dim_out (`int`): The number of channels in the output. + approximate (`str`, *optional*, defaults to `"none"`): If `"tanh"`, use tanh approximation. + bias (`bool`, defaults to True): Whether to use a bias in the linear layer. + """ + + def __init__(self, dim_in: int, dim_out: int, approximate: str = "none", bias: bool = True): + super().__init__() + self.proj = nn.Linear(dim_in, dim_out, bias=bias) + self.approximate = approximate + + def gelu(self, gate: torch.Tensor) -> torch.Tensor: + return F.gelu(gate, approximate=self.approximate) + + def forward(self, hidden_states): + hidden_states = self.proj(hidden_states) + hidden_states = self.gelu(hidden_states) + return hidden_states + + +class GEGLU(nn.Module): + r""" + Parameters: + dim_in (`int`): The number of channels in the input. + dim_out (`int`): The number of channels in the output. + bias (`bool`, defaults to True): Whether to use a bias in the linear layer. + """ + + def __init__(self, dim_in: int, dim_out: int, bias: bool = True): + super().__init__() + self.proj = nn.Linear(dim_in, dim_out * 2, bias=bias) + + def gelu(self, gate: torch.Tensor) -> torch.Tensor: + return F.gelu(gate) + + def forward(self, hidden_states, *args, **kwargs): + if len(args) > 0 or kwargs.get("scale", None) is not None: + deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`." + deprecate("scale", "1.0.0", deprecation_message) + hidden_states = self.proj(hidden_states) + if is_torch_npu_available(): + # using torch_npu.npu_geglu can run faster and save memory on NPU. + return torch_npu.npu_geglu(hidden_states, dim=-1, approximate=1)[0] + else: + hidden_states, gate = hidden_states.chunk(2, dim=-1) + return hidden_states * self.gelu(gate) + + +class SwiGLU(nn.Module): + r""" + Parameters: + dim_in (`int`): The number of channels in the input. + dim_out (`int`): The number of channels in the output. + bias (`bool`, defaults to True): Whether to use a bias in the linear layer. + """ + + def __init__(self, dim_in: int, dim_out: int, bias: bool = True): + super().__init__() + + self.proj = nn.Linear(dim_in, dim_out * 2, bias=bias) + self.activation = nn.SiLU() + + def forward(self, hidden_states): + hidden_states = self.proj(hidden_states) + hidden_states, gate = hidden_states.chunk(2, dim=-1) + return hidden_states * self.activation(gate) + + +class ApproximateGELU(nn.Module): + r""" + Parameters: + dim_in (`int`): The number of channels in the input. + dim_out (`int`): The number of channels in the output. + bias (`bool`, defaults to True): Whether to use a bias in the linear layer. 
+ """ + + def __init__(self, dim_in: int, dim_out: int, bias: bool = True): + super().__init__() + self.proj = nn.Linear(dim_in, dim_out, bias=bias) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.proj(x) + return x * torch.sigmoid(1.702 * x) + + +class LinearActivation(nn.Module): + def __init__(self, dim_in: int, dim_out: int, bias: bool = True, activation: str = "silu"): + super().__init__() + + self.proj = nn.Linear(dim_in, dim_out, bias=bias) + self.activation = get_activation(activation) + + def forward(self, hidden_states): + hidden_states = self.proj(hidden_states) + return self.activation(hidden_states) \ No newline at end of file diff --git a/MindIE/MindIE-Torch/built-in/foundation/cogview3/cogview3plus/models/attention.py b/MindIE/MindIE-Torch/built-in/foundation/cogview3/cogview3plus/models/attention.py new file mode 100644 index 0000000000000000000000000000000000000000..a7a559ff2f529de83c4620249cd37c3afe59fefb --- /dev/null +++ b/MindIE/MindIE-Torch/built-in/foundation/cogview3/cogview3plus/models/attention.py @@ -0,0 +1,87 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Optional + +import torch +from torch import nn + +from diffusers.utils import deprecate, logging +from .activations import GEGLU, GELU, ApproximateGELU, LinearActivation, SwiGLU + + +logger = logging.get_logger(__name__) + + +class FeedForward(nn.Module): + r""" + A feed-forward layer. + + Parameters: + dim (`int`): The number of channels in the input. + dim_out (`int`, *optional*): The number of channels in the output. If not given, defaults to `dim`. + mult (`int`, *optional*, defaults to 4): The multiplier to use for the hidden dimension. + dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. + activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward. + final_dropout (`bool` *optional*, defaults to False): Apply a final dropout. + bias (`bool`, defaults to True): Whether to use a bias in the linear layer. 
+ """ + + def __init__( + self, + dim: int, + dim_out: Optional[int] = None, + mult: int = 4, + dropout: float = 0.0, + activation_fn: str = "geglu", + final_dropout: bool = False, + inner_dim=None, + bias: bool = True, + ): + super().__init__() + if inner_dim is None: + inner_dim = int(dim * mult) + dim_out = dim_out if dim_out is not None else dim + + if activation_fn == "gelu": + act_fn = GELU(dim, inner_dim, bias=bias) + if activation_fn == "gelu-approximate": + act_fn = GELU(dim, inner_dim, approximate="tanh", bias=bias) + elif activation_fn == "geglu": + act_fn = GEGLU(dim, inner_dim, bias=bias) + elif activation_fn == "geglu-approximate": + act_fn = ApproximateGELU(dim, inner_dim, bias=bias) + elif activation_fn == "swiglu": + act_fn = SwiGLU(dim, inner_dim, bias=bias) + elif activation_fn == "linear-silu": + act_fn = LinearActivation(dim, inner_dim, bias=bias, activation="silu") + + self.net = nn.ModuleList([]) + # project in + self.net.append(act_fn) + # project dropout + self.net.append(nn.Dropout(dropout)) + # project out + self.net.append(nn.Linear(inner_dim, dim_out, bias=bias)) + # FF as used in Vision Transformer, MLP-Mixer, etc. have a final dropout + if final_dropout: + self.net.append(nn.Dropout(dropout)) + + def forward(self, hidden_states: torch.Tensor, *args, **kwargs) -> torch.Tensor: + if len(args) > 0 or kwargs.get("scale", None) is not None: + deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`." + deprecate("scale", "1.0.0", deprecation_message) + for module in self.net: + hidden_states = module(hidden_states) + return hidden_states \ No newline at end of file diff --git a/MindIE/MindIE-Torch/built-in/foundation/cogview3/cogview3plus/models/attention_processor.py b/MindIE/MindIE-Torch/built-in/foundation/cogview3/cogview3plus/models/attention_processor.py new file mode 100644 index 0000000000000000000000000000000000000000..c197a989b7c837ec1f6c588dd03633268c62175f --- /dev/null +++ b/MindIE/MindIE-Torch/built-in/foundation/cogview3/cogview3plus/models/attention_processor.py @@ -0,0 +1,348 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import inspect +from typing import Optional + +import torch +import torch.nn.functional as F +from torch import nn +import torch_npu + +from diffusers.utils import logging +from diffusers.utils.torch_utils import maybe_allow_in_graph + +from ..layers import QKVLinear + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +@maybe_allow_in_graph +class Attention(nn.Module): + def __init__( + self, + query_dim: int, + cross_attention_dim: Optional[int] = None, + heads: int = 8, + kv_heads: Optional[int] = None, + dim_head: int = 64, + dropout: float = 0.0, + bias: bool = False, + upcast_attention: bool = False, + upcast_softmax: bool = False, + cross_attention_norm: Optional[str] = None, + cross_attention_norm_num_groups: int = 32, + qk_norm: Optional[str] = None, + added_kv_proj_dim: Optional[int] = None, + added_proj_bias: Optional[bool] = True, + norm_num_groups: Optional[int] = None, + out_bias: bool = True, + scale_qk: bool = True, + only_cross_attention: bool = False, + eps: float = 1e-5, + rescale_output_factor: float = 1.0, + residual_connection: bool = False, + _from_deprecated_attn_block: bool = False, + processor: Optional["AttnProcessor"] = None, + out_dim: int = None, + out_context_dim: int = None, + context_pre_only=None, + pre_only=False, + elementwise_affine: bool = True, + is_causal: bool = False, + ): + super().__init__() + + # To prevent circular import. + from ..layers.normalization import FP32LayerNorm, LpNorm, RMSNorm + + self.inner_dim = out_dim if out_dim is not None else dim_head * heads + self.inner_kv_dim = self.inner_dim if kv_heads is None else dim_head * kv_heads + self.query_dim = query_dim + self.use_bias = bias + self.is_cross_attention = cross_attention_dim is not None + self.cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim + self.upcast_attention = upcast_attention + self.upcast_softmax = upcast_softmax + self.rescale_output_factor = rescale_output_factor + self.residual_connection = residual_connection + self.dropout = dropout + self.fused_projections = False + self.out_dim = out_dim if out_dim is not None else query_dim + self.out_context_dim = out_context_dim if out_context_dim is not None else query_dim + self.context_pre_only = context_pre_only + self.pre_only = pre_only + self.is_causal = is_causal + + # we make use of this private variable to know whether this class is loaded + # with an deprecated state dict so that we can convert it on the fly + self._from_deprecated_attn_block = _from_deprecated_attn_block + + self.scale_qk = scale_qk + self.scale = dim_head**-0.5 if self.scale_qk else 1.0 + + self.heads = out_dim // dim_head if out_dim is not None else heads + # for slice_size > 0 the attention score computation + # is split across the batch axis to save memory + # You can set slice_size with `set_attention_slice` + self.sliceable_head_dim = heads + + self.added_kv_proj_dim = added_kv_proj_dim + self.only_cross_attention = only_cross_attention + + if self.added_kv_proj_dim is None and self.only_cross_attention: + raise ValueError( + "`only_cross_attention` can only be set to True if `added_kv_proj_dim` is not None. Make sure to set either `only_cross_attention=False` or define `added_kv_proj_dim`." 
+ ) + + if norm_num_groups is not None: + self.group_norm = nn.GroupNorm(num_channels=query_dim, num_groups=norm_num_groups, eps=eps, affine=True) + else: + self.group_norm = None + + self.spatial_norm = None + + if qk_norm is None: + self.norm_q = None + self.norm_k = None + elif qk_norm == "layer_norm": + self.norm_q = nn.LayerNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine) + self.norm_k = nn.LayerNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine) + elif qk_norm == "fp32_layer_norm": + self.norm_q = FP32LayerNorm(dim_head, elementwise_affine=False, bias=False, eps=eps) + self.norm_k = FP32LayerNorm(dim_head, elementwise_affine=False, bias=False, eps=eps) + elif qk_norm == "layer_norm_across_heads": + # Lumina applies qk norm across all heads + self.norm_q = nn.LayerNorm(dim_head * heads, eps=eps) + self.norm_k = nn.LayerNorm(dim_head * kv_heads, eps=eps) + elif qk_norm == "rms_norm": + self.norm_q = RMSNorm(dim_head, eps=eps) + self.norm_k = RMSNorm(dim_head, eps=eps) + elif qk_norm == "rms_norm_across_heads": + # LTX applies qk norm across all heads + self.norm_q = RMSNorm(dim_head * heads, eps=eps) + self.norm_k = RMSNorm(dim_head * kv_heads, eps=eps) + elif qk_norm == "l2": + self.norm_q = LpNorm(p=2, dim=-1, eps=eps) + self.norm_k = LpNorm(p=2, dim=-1, eps=eps) + else: + raise ValueError(f"unknown qk_norm: {qk_norm}. Should be None,'layer_norm','fp32_layer_norm','rms_norm'") + + if cross_attention_norm is None: + self.norm_cross = None + elif cross_attention_norm == "layer_norm": + self.norm_cross = nn.LayerNorm(self.cross_attention_dim) + elif cross_attention_norm == "group_norm": + if self.added_kv_proj_dim is not None: + # The given `encoder_hidden_states` are initially of shape + # (batch_size, seq_len, added_kv_proj_dim) before being projected + # to (batch_size, seq_len, cross_attention_dim). The norm is applied + # before the projection, so we need to use `added_kv_proj_dim` as + # the number of channels for the group norm. + norm_cross_num_channels = added_kv_proj_dim + else: + norm_cross_num_channels = self.cross_attention_dim + + self.norm_cross = nn.GroupNorm( + num_channels=norm_cross_num_channels, num_groups=cross_attention_norm_num_groups, eps=1e-5, affine=True + ) + else: + raise ValueError( + f"unknown cross_attention_norm: {cross_attention_norm}. 
Should be None, 'layer_norm' or 'group_norm'" + ) + + self.to_qkv = QKVLinear(self.inner_dim, query_dim) + + self.added_proj_bias = added_proj_bias + if self.added_kv_proj_dim is not None: + self.add_k_proj = nn.Linear(added_kv_proj_dim, self.inner_kv_dim, bias=added_proj_bias) + self.add_v_proj = nn.Linear(added_kv_proj_dim, self.inner_kv_dim, bias=added_proj_bias) + if self.context_pre_only is not None: + self.add_q_proj = nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias) + else: + self.add_q_proj = None + self.add_k_proj = None + self.add_v_proj = None + + if not self.pre_only: + self.to_out = nn.ModuleList([]) + self.to_out.append(nn.Linear(self.inner_dim, self.out_dim, bias=out_bias)) + self.to_out.append(nn.Dropout(dropout)) + else: + self.to_out = None + + if self.context_pre_only is not None and not self.context_pre_only: + self.to_add_out = nn.Linear(self.inner_dim, self.out_context_dim, bias=out_bias) + else: + self.to_add_out = None + + if qk_norm is not None and added_kv_proj_dim is not None: + if qk_norm == "fp32_layer_norm": + self.norm_added_q = FP32LayerNorm(dim_head, elementwise_affine=False, bias=False, eps=eps) + self.norm_added_k = FP32LayerNorm(dim_head, elementwise_affine=False, bias=False, eps=eps) + elif qk_norm == "rms_norm": + self.norm_added_q = RMSNorm(dim_head, eps=eps) + self.norm_added_k = RMSNorm(dim_head, eps=eps) + else: + raise ValueError( + f"unknown qk_norm: {qk_norm}. Should be one of `None,'layer_norm','fp32_layer_norm','rms_norm'`" + ) + else: + self.norm_added_q = None + self.norm_added_k = None + + self.set_processor(processor) + + def set_processor(self, processor: "AttnProcessor") -> None: + r""" + Set the attention processor to use. + + Args: + processor (`AttnProcessor`): + The attention processor to use. + """ + if ( + hasattr(self, "processor") + and isinstance(self.processor, torch.nn.Module) + and not isinstance(processor, torch.nn.Module) + ): + logger.info(f"You are removing possibly trained weights of {self.processor} with {processor}") + self._modules.pop("processor") + + self.processor = processor + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + **cross_attention_kwargs, + ) -> torch.Tensor: + attn_parameters = set(inspect.signature(self.processor.__call__).parameters.keys()) + quiet_attn_parameters = {"ip_adapter_masks", "ip_hidden_states"} + unused_kwargs = [ + k for k, _ in cross_attention_kwargs.items() if k not in attn_parameters and k not in quiet_attn_parameters + ] + if len(unused_kwargs) > 0: + logger.warning( + f"cross_attention_kwargs {unused_kwargs} are not expected by {self.processor.__class__.__name__} and will be ignored." 
+ ) + cross_attention_kwargs = {k: w for k, w in cross_attention_kwargs.items() if k in attn_parameters} + + return self.processor( + self, + hidden_states, + encoder_hidden_states=encoder_hidden_states, + attention_mask=attention_mask, + **cross_attention_kwargs, + ) + + def prepare_attention_mask( + self, attention_mask: torch.Tensor, target_length: int, batch_size: int, out_dim: int = 3 + ) -> torch.Tensor: + head_size = self.heads + if attention_mask is None: + return attention_mask + + current_length: int = attention_mask.shape[-1] + if current_length != target_length: + if attention_mask.device.type == "mps": + padding_shape = (attention_mask.shape[0], attention_mask.shape[1], target_length) + padding = torch.zeros(padding_shape, dtype=attention_mask.dtype, device=attention_mask.device) + attention_mask = torch.cat([attention_mask, padding], dim=2) + else: + attention_mask = F.pad(attention_mask, (0, target_length), value=0.0) + + if out_dim == 3: + if attention_mask.shape[0] < batch_size * head_size: + attention_mask = attention_mask.repeat_interleave(head_size, dim=0) + elif out_dim == 4: + attention_mask = attention_mask.unsqueeze(1) + attention_mask = attention_mask.repeat_interleave(head_size, dim=1) + + return attention_mask + + +class CogVideoXAttnProcessor2_0: + r""" + Processor for implementing scaled dot-product attention for the CogVideoX model. It applies a rotary embedding on + query and key vectors, but does not include spatial normalization. + """ + + def __init__(self): + if not hasattr(F, "scaled_dot_product_attention"): + raise ImportError("CogVideoXAttnProcessor requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.") + + def __call__( + self, + attn: Attention, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + text_seq_length = encoder_hidden_states.size(1) + + hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1) + + batch_size, sequence_length, _ = ( + hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape + ) + + if attention_mask is not None: + attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) + attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1]) + + B, S, _ = hidden_states.shape + qkv = attn.to_qkv(hidden_states) + inner_dim = qkv.shape[-1] // 3 + head_dim = inner_dim // attn.heads + qkv_shape = (B, S, 3, attn.heads, head_dim) + query, key, value = qkv.view(qkv_shape).permute(2, 0, 3, 1, 4).contiguous().unbind(0) + + if attn.norm_q is not None: + query = attn.norm_q(query) + if attn.norm_k is not None: + key = attn.norm_k(key) + + B, N, S, D = query.shape + dim = 48 + pad_shape = [B, N, S, D] + pad_shape[-1] = dim - pad_shape[-1] + pad = torch.zeros(pad_shape, dtype=query.dtype, device=query.device) + query = torch.cat([query, pad], dim=-1) + key = torch.cat([key, pad], dim=-1) + value = torch.cat([value, pad], dim=-1) + hidden_states = torch_npu.npu_prompt_flash_attention( + query, + key, + value, + input_layout='BNSD', + scale_value=D**-0.5, + pre_tokens=65535, + next_tokens=65535, + num_heads=N + ) + hidden_states = hidden_states[:, :, :, :D] + + hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) + + # linear proj + hidden_states = attn.to_out[0](hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + + encoder_hidden_states, hidden_states = hidden_states.split( + 
[text_seq_length, hidden_states.size(1) - text_seq_length], dim=1 + ) + return hidden_states, encoder_hidden_states \ No newline at end of file diff --git a/MindIE/MindIE-Torch/built-in/foundation/cogview3/cogview3plus/models/model_load_utils.py b/MindIE/MindIE-Torch/built-in/foundation/cogview3/cogview3plus/models/model_load_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..34a462528398c0dc7550f58e67e84f0abe9e853a --- /dev/null +++ b/MindIE/MindIE-Torch/built-in/foundation/cogview3/cogview3plus/models/model_load_utils.py @@ -0,0 +1,42 @@ +#!/usr/bin/env python +# coding=utf-8 +# Copyright(C) 2024. Huawei Technologies Co.,Ltd. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import torch +import safetensors.torch + + +SAFETENSORS_EXTENSION = "safetensors" +EMA_STATE_DICT = "ema_state_dict" +STATE_DICT = "state_dict" +CPU = "cpu" + + +def load_state_dict_sd(model_path): + name = os.path.basename(model_path).split('.')[-1] # get weights name + if name.endswith("ckpt"): + weight = torch.load(model_path, map_location=CPU) + if (EMA_STATE_DICT in weight): + weight = weight[EMA_STATE_DICT] + weight = {key.replace("module.", ""): value for key, value in weight.items()} + elif STATE_DICT in weight: + weight = weight[STATE_DICT] + return weight + elif name == SAFETENSORS_EXTENSION: # diffuser model use same name + return safetensors.torch.load_file(model_path, device=CPU) # first load on cpu + else: + # to support hf shard model weights + return torch.load(model_path, map_location=CPU) # first load on cpu \ No newline at end of file diff --git a/MindIE/MindIE-Torch/built-in/foundation/cogview3/cogview3plus/models/modeling_utils.py b/MindIE/MindIE-Torch/built-in/foundation/cogview3/cogview3plus/models/modeling_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..aa8e33daaa56df927c7b0195e2cabc9d80a2ffd9 --- /dev/null +++ b/MindIE/MindIE-Torch/built-in/foundation/cogview3/cogview3plus/models/modeling_utils.py @@ -0,0 +1,771 @@ +# coding=utf-8 +# Copyright 2024 The HuggingFace Inc. team. +# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import copy +import itertools +import json +import os +import re +from collections import OrderedDict +from functools import wraps +from typing import Any, List, Optional, Tuple, Union + +import torch +from huggingface_hub.utils import validate_hf_hub_args +from torch import Tensor, nn + +from diffusers import __version__ +from diffusers.quantizers import DiffusersAutoQuantizer +from diffusers.quantizers.quantization_config import QuantizationMethod +from diffusers.utils import ( + CONFIG_NAME, + FLAX_WEIGHTS_NAME, + SAFETENSORS_WEIGHTS_NAME, + WEIGHTS_NAME, + _add_variant, + _get_checkpoint_shard_files, + _get_model_file, + deprecate, + is_accelerate_available, + is_bitsandbytes_version, + logging, +) +from diffusers.utils.hub_utils import PushToHubMixin +from diffusers.models.model_loading_utils import ( + _fetch_index_file, + _fetch_index_file_legacy, + _load_state_dict_into_model, + _merge_sharded_checkpoints, + load_model_dict_into_meta, + load_state_dict, +) + + +logger = logging.get_logger(__name__) + + +_LOW_CPU_MEM_USAGE_DEFAULT = True + + +if is_accelerate_available(): + import accelerate + + +def get_parameter_device(parameter: torch.nn.Module) -> torch.device: + try: + parameters_and_buffers = itertools.chain(parameter.parameters(), parameter.buffers()) + return next(parameters_and_buffers).device + except StopIteration: + # For torch.nn.DataParallel compatibility in PyTorch 1.5 + + def find_tensor_attributes(module: torch.nn.Module) -> List[Tuple[str, Tensor]]: + tuples = [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)] + return tuples + + gen = parameter._named_members(get_members_fn=find_tensor_attributes) + first_tuple = next(gen) + return first_tuple[1].device + + +def get_parameter_dtype(parameter: torch.nn.Module) -> torch.dtype: + """ + Returns the first found floating dtype in parameters if there is one, otherwise returns the last dtype it found. 
+ """ + last_dtype = None + for param in parameter.parameters(): + last_dtype = param.dtype + if param.is_floating_point(): + return param.dtype + + for buffer in parameter.buffers(): + last_dtype = buffer.dtype + if buffer.is_floating_point(): + return buffer.dtype + + if last_dtype is not None: + # if no floating dtype was found return whatever the first dtype is + return last_dtype + + # For nn.DataParallel compatibility in PyTorch > 1.5 + def find_tensor_attributes(module: nn.Module) -> List[Tuple[str, Tensor]]: + tuples = [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)] + return tuples + + gen = parameter._named_members(get_members_fn=find_tensor_attributes) + last_tuple = None + for current_tuple in gen: + last_tuple = current_tuple + if current_tuple[1].is_floating_point(): + return current_tuple[1].dtype + + if last_tuple is not None: + # fallback to the last dtype + return last_tuple[1].dtype + + +class ModelMixin(torch.nn.Module, PushToHubMixin): + config_name = CONFIG_NAME + _automatically_saved_args = ["_diffusers_version", "_class_name", "_name_or_path"] + _supports_gradient_checkpointing = False + _keys_to_ignore_on_load_unexpected = None + _no_split_modules = None + _keep_in_fp32_modules = None + + def __init__(self): + super().__init__() + + def __getattr__(self, name: str) -> Any: + + is_in_config = "_internal_dict" in self.__dict__ and hasattr(self.__dict__["_internal_dict"], name) + is_attribute = name in self.__dict__ + + if is_in_config and not is_attribute: + deprecation_message = f"Accessing config attribute `{name}` directly via '{type(self).__name__}' object attribute is deprecated. Please access '{name}' over '{type(self).__name__}'s config object instead, e.g. 'unet.config.{name}'." + deprecate("direct config name access", "1.0.0", deprecation_message, standard_warn=False, stacklevel=3) + return self._internal_dict[name] + + return super().__getattr__(name) + + @classmethod + @validate_hf_hub_args + def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs): + cache_dir = kwargs.pop("cache_dir", None) + ignore_mismatched_sizes = kwargs.pop("ignore_mismatched_sizes", False) + force_download = kwargs.pop("force_download", False) + from_flax = kwargs.pop("from_flax", False) + proxies = kwargs.pop("proxies", None) + output_loading_info = kwargs.pop("output_loading_info", False) + local_files_only = kwargs.pop("local_files_only", None) + token = kwargs.pop("token", None) + revision = kwargs.pop("revision", None) + torch_dtype = kwargs.pop("torch_dtype", None) + subfolder = kwargs.pop("subfolder", None) + device_map = kwargs.pop("device_map", None) + low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT) + variant = kwargs.pop("variant", None) + use_safetensors = kwargs.pop("use_safetensors", None) + quantization_config = kwargs.pop("quantization_config", None) + + allow_pickle = False + if use_safetensors is None: + use_safetensors = True + allow_pickle = True + + if low_cpu_mem_usage and not is_accelerate_available(): + low_cpu_mem_usage = False + logger.warning( + "Cannot initialize model with low cpu memory usage because `accelerate` was not found in the" + " environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install" + " `accelerate` for faster and less memory-intense model loading. You can do so with: \n```\npip" + " install accelerate\n```\n." 
+ ) + + if device_map is not None and not is_accelerate_available(): + raise NotImplementedError( + "Loading and dispatching requires `accelerate`. Please make sure to install accelerate or set" + " `device_map=None`. You can install accelerate with `pip install accelerate`." + ) + + if low_cpu_mem_usage is False and device_map is not None: + raise ValueError( + f"You cannot set `low_cpu_mem_usage` to `False` while using device_map={device_map} for loading and" + " dispatching. Please make sure to set `low_cpu_mem_usage=True`." + ) + + if isinstance(device_map, torch.device): + device_map = {"": device_map} + elif isinstance(device_map, str) and device_map not in ["auto", "balanced", "balanced_low_0", "sequential"]: + try: + device_map = {"": torch.device(device_map)} + except RuntimeError as e: + raise ValueError( + "When passing device_map as a string, the value needs to be a device name (e.g. cpu, cuda:0) or " + f"'auto', 'balanced', 'balanced_low_0', 'sequential' but found {device_map}." + ) from e + elif isinstance(device_map, int): + if device_map < 0: + raise ValueError( + "You can't pass device_map as a negative int. If you want to put the model on the cpu, pass device_map = 'cpu' " + ) + else: + device_map = {"": device_map} + + if device_map is not None: + if low_cpu_mem_usage is None: + low_cpu_mem_usage = True + elif not low_cpu_mem_usage: + raise ValueError("Passing along a `device_map` requires `low_cpu_mem_usage=True`") + + # Load config if we don't provide a configuration + config_path = pretrained_model_name_or_path + + user_agent = { + "diffusers": __version__, + "file_type": "model", + "framework": "pytorch", + } + + # load config + config, unused_kwargs, commit_hash = cls.load_config( + config_path, + cache_dir=cache_dir, + return_unused_kwargs=True, + return_commit_hash=True, + force_download=force_download, + proxies=proxies, + local_files_only=local_files_only, + token=token, + revision=revision, + subfolder=subfolder, + user_agent=user_agent, + **kwargs, + ) + # no in-place modification of the original config. + config = copy.deepcopy(config) + + # determine initial quantization config. + ####################################### + pre_quantized = "quantization_config" in config and config["quantization_config"] is not None + if pre_quantized or quantization_config is not None: + if pre_quantized: + config["quantization_config"] = DiffusersAutoQuantizer.merge_quantization_configs( + config["quantization_config"], quantization_config + ) + else: + config["quantization_config"] = quantization_config + hf_quantizer = DiffusersAutoQuantizer.from_config( + config["quantization_config"], pre_quantized=pre_quantized + ) + else: + hf_quantizer = None + + if hf_quantizer is not None: + is_bnb_quantization_method = hf_quantizer.quantization_config.quant_method.value == "bitsandbytes" + if is_bnb_quantization_method and device_map is not None: + raise NotImplementedError( + "Currently, `device_map` is automatically inferred for quantized bitsandbytes models. Support for providing `device_map` as an input will be added in the future." + ) + + hf_quantizer.validate_environment(torch_dtype=torch_dtype, from_flax=from_flax, device_map=device_map) + torch_dtype = hf_quantizer.update_torch_dtype(torch_dtype) + + # In order to ensure popular quantization methods are supported. 
Can be disable with `disable_telemetry` + user_agent["quant"] = hf_quantizer.quantization_config.quant_method.value + + # Force-set to `True` for more mem efficiency + if low_cpu_mem_usage is None: + low_cpu_mem_usage = True + logger.info("Set `low_cpu_mem_usage` to True as `hf_quantizer` is not None.") + elif not low_cpu_mem_usage: + raise ValueError("`low_cpu_mem_usage` cannot be False or None when using quantization.") + + # Check if `_keep_in_fp32_modules` is not None + use_keep_in_fp32_modules = (cls._keep_in_fp32_modules is not None) and ( + (torch_dtype == torch.float16) or hasattr(hf_quantizer, "use_keep_in_fp32_modules") + ) + if use_keep_in_fp32_modules: + keep_in_fp32_modules = cls._keep_in_fp32_modules + if not isinstance(keep_in_fp32_modules, list): + keep_in_fp32_modules = [keep_in_fp32_modules] + + if low_cpu_mem_usage is None: + low_cpu_mem_usage = True + logger.info("Set `low_cpu_mem_usage` to True as `_keep_in_fp32_modules` is not None.") + elif not low_cpu_mem_usage: + raise ValueError("`low_cpu_mem_usage` cannot be False when `keep_in_fp32_modules` is True.") + else: + keep_in_fp32_modules = [] + ####################################### + + # Determine if we're loading from a directory of sharded checkpoints. + is_sharded = False + index_file = None + is_local = os.path.isdir(pretrained_model_name_or_path) + index_file_kwargs = { + "is_local": is_local, + "pretrained_model_name_or_path": pretrained_model_name_or_path, + "subfolder": subfolder or "", + "use_safetensors": use_safetensors, + "cache_dir": cache_dir, + "variant": variant, + "force_download": force_download, + "proxies": proxies, + "local_files_only": local_files_only, + "token": token, + "revision": revision, + "user_agent": user_agent, + "commit_hash": commit_hash, + } + index_file = _fetch_index_file(**index_file_kwargs) + # In case the index file was not found we still have to consider the legacy format. + # this becomes applicable when the variant is not None. 
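+            # If either lookup finds an index file on disk, the checkpoint is treated as sharded
+            # and the individual shard files are resolved through `_get_checkpoint_shard_files` below.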
+ if variant is not None and (index_file is None or not os.path.exists(index_file)): + index_file = _fetch_index_file_legacy(**index_file_kwargs) + if index_file is not None and index_file.is_file(): + is_sharded = True + + if is_sharded and from_flax: + raise ValueError("Loading of sharded checkpoints is not supported when `from_flax=True`.") + + # load model + model_file = None + if from_flax: + model_file = _get_model_file( + pretrained_model_name_or_path, + weights_name=FLAX_WEIGHTS_NAME, + cache_dir=cache_dir, + force_download=force_download, + proxies=proxies, + local_files_only=local_files_only, + token=token, + revision=revision, + subfolder=subfolder, + user_agent=user_agent, + commit_hash=commit_hash, + ) + model = cls.from_config(config, **unused_kwargs) + + # Convert the weights + from .modeling_pytorch_flax_utils import load_flax_checkpoint_in_pytorch_model + + model = load_flax_checkpoint_in_pytorch_model(model, model_file) + else: + if is_sharded: + sharded_ckpt_cached_folder, sharded_metadata = _get_checkpoint_shard_files( + pretrained_model_name_or_path, + index_file, + cache_dir=cache_dir, + proxies=proxies, + local_files_only=local_files_only, + token=token, + user_agent=user_agent, + revision=revision, + subfolder=subfolder or "", + ) + if hf_quantizer is not None and is_bnb_quantization_method: + model_file = _merge_sharded_checkpoints(sharded_ckpt_cached_folder, sharded_metadata) + logger.info("Merged sharded checkpoints as `hf_quantizer` is not None.") + is_sharded = False + + elif use_safetensors and not is_sharded: + try: + model_file = _get_model_file( + pretrained_model_name_or_path, + weights_name=_add_variant(SAFETENSORS_WEIGHTS_NAME, variant), + cache_dir=cache_dir, + force_download=force_download, + proxies=proxies, + local_files_only=local_files_only, + token=token, + revision=revision, + subfolder=subfolder, + user_agent=user_agent, + commit_hash=commit_hash, + ) + + except IOError as e: + logger.error(f"An error occurred while trying to fetch {pretrained_model_name_or_path}: {e}") + if not allow_pickle: + raise + logger.warning( + "Defaulting to unsafe serialization. Pass `allow_pickle=False` to raise an error instead." + ) + + if model_file is None and not is_sharded: + model_file = _get_model_file( + pretrained_model_name_or_path, + weights_name=_add_variant(WEIGHTS_NAME, variant), + cache_dir=cache_dir, + force_download=force_download, + proxies=proxies, + local_files_only=local_files_only, + token=token, + revision=revision, + subfolder=subfolder, + user_agent=user_agent, + commit_hash=commit_hash, + ) + + if low_cpu_mem_usage: + # Instantiate model with empty weights + with accelerate.init_empty_weights(): + model = cls.from_config(config, **unused_kwargs) + + if hf_quantizer is not None: + hf_quantizer.preprocess_model( + model=model, device_map=device_map, keep_in_fp32_modules=keep_in_fp32_modules + ) + + # if device_map is None, load the state dict and move the params from meta device to the cpu + if device_map is None and not is_sharded: + # `torch.cuda.current_device()` is fine here when `hf_quantizer` is not None. + # It would error out during the `validate_environment()` call above in the absence of cuda. 
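+                # In short: without a quantizer the weights are materialized on CPU first, while a
+                # quantized checkpoint is loaded directly onto the current accelerator, where the
+                # quantization backend expects to create its packed weights.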
+ if hf_quantizer is None: + param_device = "cpu" + else: + param_device = torch.device(torch.cuda.current_device()) + state_dict = load_state_dict(model_file, variant=variant) + model._convert_deprecated_attention_blocks(state_dict) + + # move the params from meta device to cpu + missing_keys = set(model.state_dict().keys()) - set(state_dict.keys()) + if hf_quantizer is not None: + missing_keys = hf_quantizer.update_missing_keys(model, missing_keys, prefix="") + if len(missing_keys) > 0: + raise ValueError( + f"Cannot load {cls} from {pretrained_model_name_or_path} because the following keys are" + f" missing: \n {', '.join(missing_keys)}. \n Please make sure to pass" + " `low_cpu_mem_usage=False` and `device_map=None` if you want to randomly initialize" + " those weights or else make sure your checkpoint file is correct." + ) + + unexpected_keys = load_model_dict_into_meta( + model, + state_dict, + device=param_device, + dtype=torch_dtype, + model_name_or_path=pretrained_model_name_or_path, + hf_quantizer=hf_quantizer, + keep_in_fp32_modules=keep_in_fp32_modules, + ) + + if cls._keys_to_ignore_on_load_unexpected is not None: + for pat in cls._keys_to_ignore_on_load_unexpected: + unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None] + + if len(unexpected_keys) > 0: + logger.warning( + f"Some weights of the model checkpoint were not used when initializing {cls.__name__}: \n {[', '.join(unexpected_keys)]}" + ) + + else: + weights_path = index_file + with open(index_file) as f: + index = json.loads(f.read()) + if "weight_map" in index: + index = index["weight_map"] + weights_path = sorted(list(set(index.values()))) + weights_path = [os.path.join(pretrained_model_name_or_path, f) for f in weights_path] + + model = cls._load_model(model, weights_path, is_sharded) + + loading_info = { + "missing_keys": [], + "unexpected_keys": [], + "mismatched_keys": [], + "error_msgs": [], + } + else: + model = cls.from_config(config, **unused_kwargs) + + state_dict = load_state_dict(model_file, variant=variant) + model._convert_deprecated_attention_blocks(state_dict) + + model, missing_keys, unexpected_keys, mismatched_keys, error_msgs = cls._load_pretrained_model( + model, + state_dict, + model_file, + pretrained_model_name_or_path, + ignore_mismatched_sizes=ignore_mismatched_sizes, + ) + + loading_info = { + "missing_keys": missing_keys, + "unexpected_keys": unexpected_keys, + "mismatched_keys": mismatched_keys, + "error_msgs": error_msgs, + } + + if hf_quantizer is not None: + hf_quantizer.postprocess_model(model) + model.hf_quantizer = hf_quantizer + + if torch_dtype is not None and not isinstance(torch_dtype, torch.dtype): + raise ValueError( + f"{torch_dtype} needs to be of type `torch.dtype`, e.g. `torch.float16`, but is {type(torch_dtype)}." + ) + # When using `use_keep_in_fp32_modules` if we do a global `to()` here, then we will + # completely lose the effectivity of `use_keep_in_fp32_modules`. + elif torch_dtype is not None and hf_quantizer is None and not use_keep_in_fp32_modules: + model = model.to(torch_dtype) + + if hf_quantizer is not None: + # We also make sure to purge `_pre_quantization_dtype` when we serialize + # the model config because `_pre_quantization_dtype` is `torch.dtype`, not JSON serializable. 
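+            # It is still registered on the in-memory config here so later code can tell which dtype
+            # the weights had before quantization; per the note above it is stripped again when the
+            # config is serialized, since `torch.dtype` is not JSON serializable.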
+ model.register_to_config(_name_or_path=pretrained_model_name_or_path, _pre_quantization_dtype=torch_dtype) + else: + model.register_to_config(_name_or_path=pretrained_model_name_or_path) + + # Set model in evaluation mode to deactivate DropOut modules by default + model.eval() + if output_loading_info: + return model, loading_info + + return model + + @classmethod + def _load_model(cls, model, weights_path, is_sharded): + if not is_sharded: + state_dict = load_state_dict(weights_path) + model.load_weights(state_dict) + else: + need_key = set(model.state_dict().keys()) + state_dict = {} + cache = {} + for weight_file in weights_path: + state_dict = load_state_dict(weight_file) + state_dict.update(cache) + loadkey_cache = model.load_weights(state_dict, is_sharded) + if loadkey_cache : + if isinstance(loadkey_cache, tuple): + loaded_keys, cache = loadkey_cache + else: + loaded_keys = loadkey_cache + need_key = need_key.symmetric_difference(set(loaded_keys)) + + if len(need_key) > 0: + raise ValueError(f"The weight miss key: {need_key}") + return model + + def load_weights(self, state_dict, shard=False): + with torch.no_grad(): + if not shard: + self.load_state_dict(state_dict) + return {} + else: + self.load_state_dict(state_dict, strict=False, assign=True) + return state_dict.keys() + + # Adapted from `transformers`. + @wraps(torch.nn.Module.cuda) + def cuda(self, *args, **kwargs): + # Checks if the model has been loaded in 4-bit or 8-bit with BNB + if getattr(self, "quantization_method", None) == QuantizationMethod.BITS_AND_BYTES: + if getattr(self, "is_loaded_in_8bit", False): + raise ValueError( + "Calling `cuda()` is not supported for `8-bit` quantized models. " + " Please use the model as it is, since the model has already been set to the correct devices." + ) + elif is_bitsandbytes_version("<", "0.43.2"): + raise ValueError( + "Calling `cuda()` is not supported for `4-bit` quantized models with the installed version of bitsandbytes. " + f"The current device is `{self.device}`. If you intended to move the model, please install bitsandbytes >= 0.43.2." + ) + return super().cuda(*args, **kwargs) + + # Adapted from `transformers`. + @wraps(torch.nn.Module.to) + def to(self, *args, **kwargs): + dtype_present_in_args = "dtype" in kwargs + + if not dtype_present_in_args: + for arg in args: + if isinstance(arg, torch.dtype): + dtype_present_in_args = True + break + + if getattr(self, "is_quantized", False): + if dtype_present_in_args: + raise ValueError( + "Casting a quantized model to a new `dtype` is unsupported. To set the dtype of unquantized layers, please " + "use the `torch_dtype` argument when loading the model using `from_pretrained` or `from_single_file`" + ) + + if getattr(self, "quantization_method", None) == QuantizationMethod.BITS_AND_BYTES: + if getattr(self, "is_loaded_in_8bit", False): + raise ValueError( + "`.to` is not supported for `8-bit` bitsandbytes models. Please use the model as it is, since the" + " model has already been set to the correct devices and casted to the correct `dtype`." + ) + elif is_bitsandbytes_version("<", "0.43.2"): + raise ValueError( + "Calling `to()` is not supported for `4-bit` quantized models with the installed version of bitsandbytes. " + f"The current device is `{self.device}`. If you intended to move the model, please install bitsandbytes >= 0.43.2." + ) + return super().to(*args, **kwargs) + + # Taken from `transformers`. 
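+    # Like the `to()`/`cuda()` overrides above, `half()` and `float()` reject blanket dtype casts
+    # for quantized models: the weights are already stored in their quantized format, and a global
+    # cast would invalidate them.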
+ def half(self, *args): + # Checks if the model is quantized + if getattr(self, "is_quantized", False): + raise ValueError( + "`.half()` is not supported for quantized model. Please use the model as it is, since the" + " model has already been cast to the correct `dtype`." + ) + else: + return super().half(*args) + + # Taken from `transformers`. + def float(self, *args): + # Checks if the model is quantized + if getattr(self, "is_quantized", False): + raise ValueError( + "`.float()` is not supported for quantized model. Please use the model as it is, since the" + " model has already been cast to the correct `dtype`." + ) + else: + return super().float(*args) + + @classmethod + def _load_pretrained_model( + cls, + model, + state_dict: OrderedDict, + pretrained_model_name_or_path: Union[str, os.PathLike], + ignore_mismatched_sizes: bool = False, + ): + # Retrieve missing & unexpected_keys + model_state_dict = model.state_dict() + loaded_keys = list(state_dict.keys()) + + expected_keys = list(model_state_dict.keys()) + + original_loaded_keys = loaded_keys + + missing_keys = list(set(expected_keys) - set(loaded_keys)) + unexpected_keys = list(set(loaded_keys) - set(expected_keys)) + + # Make sure we are able to load base models as well as derived models (with heads) + model_to_load = model + + def _find_mismatched_keys( + state_dict, + model_state_dict, + loaded_keys, + ignore_mismatched_sizes, + ): + mismatched_keys = [] + if ignore_mismatched_sizes: + for checkpoint_key in loaded_keys: + model_key = checkpoint_key + + if ( + model_key in model_state_dict + and state_dict[checkpoint_key].shape != model_state_dict[model_key].shape + ): + mismatched_keys.append( + (checkpoint_key, state_dict[checkpoint_key].shape, model_state_dict[model_key].shape) + ) + del state_dict[checkpoint_key] + return mismatched_keys + + if state_dict is not None: + # Whole checkpoint + mismatched_keys = _find_mismatched_keys( + state_dict, + model_state_dict, + original_loaded_keys, + ignore_mismatched_sizes, + ) + error_msgs = _load_state_dict_into_model(model_to_load, state_dict) + + if len(error_msgs) > 0: + error_msg = "\n\t".join(error_msgs) + if "size mismatch" in error_msg: + error_msg += ( + "\n\tYou may consider adding `ignore_mismatched_sizes=True` in the model `from_pretrained` method." + ) + raise RuntimeError(f"Error(s) in loading state_dict for {model.__class__.__name__}:\n\t{error_msg}") + + if len(unexpected_keys) > 0: + logger.warning( + f"Some weights of the model checkpoint at {pretrained_model_name_or_path} were not used when" + f" initializing {model.__class__.__name__}: {unexpected_keys}\n- This IS expected if you are" + f" initializing {model.__class__.__name__} from the checkpoint of a model trained on another task" + " or with another architecture (e.g. initializing a BertForSequenceClassification model from a" + " BertForPreTraining model).\n- This IS NOT expected if you are initializing" + f" {model.__class__.__name__} from the checkpoint of a model that you expect to be exactly" + " identical (initializing a BertForSequenceClassification model from a" + " BertForSequenceClassification model)." 
+ ) + else: + logger.info(f"All model checkpoint weights were used when initializing {model.__class__.__name__}.\n") + if len(missing_keys) > 0: + logger.warning( + f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at" + f" {pretrained_model_name_or_path} and are newly initialized: {missing_keys}\nYou should probably" + " TRAIN this model on a down-stream task to be able to use it for predictions and inference." + ) + elif len(mismatched_keys) == 0: + logger.info( + f"All the weights of {model.__class__.__name__} were initialized from the model checkpoint at" + f" {pretrained_model_name_or_path}.\nIf your task is similar to the task the model of the" + f" checkpoint was trained on, you can already use {model.__class__.__name__} for predictions" + " without further training." + ) + if len(mismatched_keys) > 0: + mismatched_warning = "\n".join( + [ + f"- {key}: found shape {shape1} in the checkpoint and {shape2} in the model instantiated" + for key, shape1, shape2 in mismatched_keys + ] + ) + logger.warning( + f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at" + f" {pretrained_model_name_or_path} and are newly initialized because the shapes did not" + f" match:\n{mismatched_warning}\nYou should probably TRAIN this model on a down-stream task to be" + " able to use it for predictions and inference." + ) + + return model, missing_keys, unexpected_keys, mismatched_keys, error_msgs + + @property + def device(self) -> torch.device: + return get_parameter_device(self) + + @property + def dtype(self) -> torch.dtype: + return get_parameter_dtype(self) + + def _convert_deprecated_attention_blocks(self, state_dict: OrderedDict) -> None: + deprecated_attention_block_paths = [] + + def recursive_find_attn_block(name, module): + if hasattr(module, "_from_deprecated_attn_block") and module._from_deprecated_attn_block: + deprecated_attention_block_paths.append(name) + + for sub_name, sub_module in module.named_children(): + sub_name = sub_name if name == "" else f"{name}.{sub_name}" + recursive_find_attn_block(sub_name, sub_module) + + recursive_find_attn_block("", self) + + for path in deprecated_attention_block_paths: + # group_norm path stays the same + + # query -> to_q + if f"{path}.query.weight" in state_dict: + state_dict[f"{path}.to_q.weight"] = state_dict.pop(f"{path}.query.weight") + if f"{path}.query.bias" in state_dict: + state_dict[f"{path}.to_q.bias"] = state_dict.pop(f"{path}.query.bias") + + # key -> to_k + if f"{path}.key.weight" in state_dict: + state_dict[f"{path}.to_k.weight"] = state_dict.pop(f"{path}.key.weight") + if f"{path}.key.bias" in state_dict: + state_dict[f"{path}.to_k.bias"] = state_dict.pop(f"{path}.key.bias") + + # value -> to_v + if f"{path}.value.weight" in state_dict: + state_dict[f"{path}.to_v.weight"] = state_dict.pop(f"{path}.value.weight") + if f"{path}.value.bias" in state_dict: + state_dict[f"{path}.to_v.bias"] = state_dict.pop(f"{path}.value.bias") + + # proj_attn -> to_out.0 + if f"{path}.proj_attn.weight" in state_dict: + state_dict[f"{path}.to_out.0.weight"] = state_dict.pop(f"{path}.proj_attn.weight") + if f"{path}.proj_attn.bias" in state_dict: + state_dict[f"{path}.to_out.0.bias"] = state_dict.pop(f"{path}.proj_attn.bias") diff --git a/MindIE/MindIE-Torch/built-in/foundation/cogview3/cogview3plus/models/transformer_cogview3plus.py b/MindIE/MindIE-Torch/built-in/foundation/cogview3/cogview3plus/models/transformer_cogview3plus.py new file mode 100644 index 
0000000000000000000000000000000000000000..f704e22589437bb8533e428f8cf1bbab7375f2ae --- /dev/null +++ b/MindIE/MindIE-Torch/built-in/foundation/cogview3/cogview3plus/models/transformer_cogview3plus.py @@ -0,0 +1,397 @@ +# Copyright 2024 The CogView team, Tsinghua University & ZhipuAI and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Any, Dict, Union + +import torch +import torch.nn as nn +import numpy as np + +from diffusers.configuration_utils import ConfigMixin, register_to_config +from diffusers.models.attention_processor import AttentionProcessor +from diffusers.utils import logging +from diffusers.models.modeling_outputs import Transformer2DModelOutput + +from .modeling_utils import ModelMixin +from .attention import FeedForward +from .attention_processor import CogVideoXAttnProcessor2_0, Attention +from ..layers import CogView3PlusAdaLayerNormZeroTextImage, AdaLayerNormContinuous +from ..layers import CogView3CombinedTimestepSizeEmbeddings, CogView3PlusPatchEmbed + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +class CogView3PlusTransformerBlock(nn.Module): + def __init__( + self, + dim: int = 2560, + num_attention_heads: int = 64, + attention_head_dim: int = 40, + time_embed_dim: int = 512, + ): + super().__init__() + + self.norm1 = CogView3PlusAdaLayerNormZeroTextImage(embedding_dim=time_embed_dim, dim=dim) + + self.attn1 = Attention( + query_dim=dim, + heads=num_attention_heads, + dim_head=attention_head_dim, + out_dim=dim, + bias=True, + qk_norm="layer_norm", + elementwise_affine=False, + eps=1e-6, + processor=CogVideoXAttnProcessor2_0(), + ) + + self.norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-5) + self.norm2_context = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-5) + + self.ff = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate") + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor, + emb: torch.Tensor, + ) -> torch.Tensor: + text_seq_length = encoder_hidden_states.size(1) + + # norm & modulate + norm_hidden_states, chunk_params = self.norm1(hidden_states, encoder_hidden_states, emb) + + gate_msa = chunk_params.gate_msa + shift_mlp = chunk_params.shift_mlp + scale_mlp = chunk_params.scale_mlp + gate_mlp = chunk_params.gate_mlp + norm_encoder_hidden_states = chunk_params.context + c_gate_msa = chunk_params.c_gate_msa + c_shift_mlp = chunk_params.c_shift_mlp + c_scale_mlp = chunk_params.c_scale_mlp + c_gate_mlp = chunk_params.c_gate_mlp + + # attention + attn_hidden_states, attn_encoder_hidden_states = self.attn1( + hidden_states=norm_hidden_states, encoder_hidden_states=norm_encoder_hidden_states + ) + + hidden_states = hidden_states + gate_msa.unsqueeze(1) * attn_hidden_states + encoder_hidden_states = encoder_hidden_states + c_gate_msa.unsqueeze(1) * attn_encoder_hidden_states + + # norm & modulate + norm_hidden_states = self.norm2(hidden_states) + norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, 
None]) + shift_mlp[:, None] + + norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states) + norm_encoder_hidden_states = norm_encoder_hidden_states * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None] + + # feed-forward + norm_hidden_states = torch.cat([norm_encoder_hidden_states, norm_hidden_states], dim=1) + ff_output = self.ff(norm_hidden_states) + + hidden_states = hidden_states + gate_mlp.unsqueeze(1) * ff_output[:, text_seq_length:] + encoder_hidden_states = encoder_hidden_states + c_gate_mlp.unsqueeze(1) * ff_output[:, :text_seq_length] + + if hidden_states.dtype == torch.float16: + hidden_states = hidden_states.clip(-65504, 65504) + if encoder_hidden_states.dtype == torch.float16: + encoder_hidden_states = encoder_hidden_states.clip(-65504, 65504) + return hidden_states, encoder_hidden_states + + +class CogView3PlusTransformer2DModel(ModelMixin, ConfigMixin): + _supports_gradient_checkpointing = True + + @register_to_config + def __init__( + self, + patch_size: int = 2, + in_channels: int = 16, + num_layers: int = 30, + attention_head_dim: int = 40, + num_attention_heads: int = 64, + out_channels: int = 16, + text_embed_dim: int = 4096, + time_embed_dim: int = 512, + condition_dim: int = 256, + pos_embed_max_size: int = 128, + use_cache: bool = True, + cache_interval: int = 2, + cache_start: int = 3, + num_cache_layer: int = 13, + cache_start_steps: int = 5, + ): + super().__init__() + self.out_channels = out_channels + self.inner_dim = num_attention_heads * attention_head_dim + self.num_layers = num_layers + + # CogView3 uses 3 additional SDXL-like conditions - original_size, target_size, crop_coords + # Each of these are sincos embeddings of shape 2 * condition_dim + self.pooled_projection_dim = 3 * 2 * condition_dim + + self.patch_embed = CogView3PlusPatchEmbed( + in_channels=in_channels, + hidden_size=self.inner_dim, + patch_size=patch_size, + text_hidden_size=text_embed_dim, + pos_embed_max_size=pos_embed_max_size, + ) + + self.time_condition_embed = CogView3CombinedTimestepSizeEmbeddings( + embedding_dim=time_embed_dim, + condition_dim=condition_dim, + pooled_projection_dim=self.pooled_projection_dim, + timesteps_dim=self.inner_dim, + ) + + self.transformer_blocks = nn.ModuleList( + [ + CogView3PlusTransformerBlock( + dim=self.inner_dim, + num_attention_heads=num_attention_heads, + attention_head_dim=attention_head_dim, + time_embed_dim=time_embed_dim, + ) + for _ in range(num_layers) + ] + ) + + self.norm_out = AdaLayerNormContinuous( + embedding_dim=self.inner_dim, + conditioning_embedding_dim=time_embed_dim, + elementwise_affine=False, + eps=1e-6, + ) + self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=True) + + self.gradient_checkpointing = False + + self.q_weight_cache = None + self.q_bias_cache = None + self.k_weight_cache = None + self.k_bias_cache = None + self.v_weight_cache = None + self.v_bias_cache = None + + self.use_cache = use_cache + self.cache_interval = cache_interval + self.cache_start = cache_start + self.num_cache_layer = num_cache_layer + self.cache_start_steps = cache_start_steps + + self.delta_cache = None + self.delta_encoder_cache = None + + @property + def attn_processors(self) -> Dict[str, AttentionProcessor]: + r""" + Returns: + `dict` of attention processors: A dictionary containing all attention processors used in the model with + indexed by its weight name. 
+ """ + # set recursively + processors = {} + + def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]): + if hasattr(module, "get_processor"): + processors[f"{name}.processor"] = module.get_processor() + + for sub_name, child in module.named_children(): + fn_recursive_add_processors(f"{name}.{sub_name}", child, processors) + + return processors + + for name, module in self.named_children(): + fn_recursive_add_processors(name, module, processors) + + return processors + + def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]): + count = len(self.attn_processors.keys()) + + if isinstance(processor, dict) and len(processor) != count: + raise ValueError( + f"A dict of processors was passed, but the number of processors {len(processor)} does not match the" + f" number of attention layers: {count}. Please make sure to pass {count} processor classes." + ) + + def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor): + if hasattr(module, "set_processor"): + if not isinstance(processor, dict): + module.set_processor(processor) + else: + module.set_processor(processor.pop(f"{name}.processor")) + + for sub_name, child in module.named_children(): + fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor) + + for name, module in self.named_children(): + fn_recursive_attn_processor(name, module, processor) + + def _set_gradient_checkpointing(self, module, value=False): + if hasattr(module, "gradient_checkpointing"): + module.gradient_checkpointing = value + + def forward( + self, + states, + timestep: torch.LongTensor, + original_size: torch.Tensor, + target_size: torch.Tensor, + crop_coords: torch.Tensor, + ) -> Union[torch.Tensor, Transformer2DModelOutput]: + hidden_states = states[0] + encoder_hidden_states = states[1] + height, width = hidden_states.shape[-2:] + text_seq_length = encoder_hidden_states.shape[1] + + hidden_states = self.patch_embed( + hidden_states, encoder_hidden_states + ) # takes care of adding positional embeddings too. 
+ emb = self.time_condition_embed(timestep, original_size, target_size, crop_coords, hidden_states.dtype) + + encoder_hidden_states = hidden_states[:, :text_seq_length] + hidden_states = hidden_states[:, text_seq_length:] + + hidden_states, encoder_hidden_states = self._forward_blocks(hidden_states, encoder_hidden_states, emb, states[2]) + + hidden_states = self.norm_out(hidden_states, emb) + hidden_states = self.proj_out(hidden_states) # (batch_size, height*width, patch_size*patch_size*out_channels) + + # unpatchify + patch_size = self.config.patch_size + height = height // patch_size + width = width // patch_size + + hidden_states = hidden_states.reshape( + shape=(hidden_states.shape[0], height, width, self.out_channels, patch_size, patch_size) + ) + hidden_states = torch.einsum("nhwcpq->nchpwq", hidden_states) + output = hidden_states.reshape( + shape=(hidden_states.shape[0], self.out_channels, height * patch_size, width * patch_size) + ) + + return Transformer2DModelOutput(sample=output) + + # forward blocks in range [start_idx, end_idx), then return input and output + def _forward_blocks_range(self, hidden_states, encoder_hidden_states, emb, start_idx, end_idx, **kwargs): + for _, block in enumerate(self.transformer_blocks[start_idx: end_idx]): + hidden_states, encoder_hidden_states = block( + hidden_states=hidden_states, + encoder_hidden_states=encoder_hidden_states, + emb=emb, + ) + + return hidden_states, encoder_hidden_states + + def _forward_blocks(self, hidden_states, encoder_hidden_states, emb, t_idx): + num_blocks = len(self.transformer_blocks) + + if not self.use_cache or (t_idx < self.cache_start_steps): + hidden_states, encoder_hidden_states = self._forward_blocks_range( + hidden_states, + encoder_hidden_states, + emb, + 0, + num_blocks + ) + else: + # infer [0, cache_start) + hidden_states, encoder_hidden_states = self._forward_blocks_range( + hidden_states, + encoder_hidden_states, + emb, + 0, + self.cache_start + ) + # infer [cache_start, cache_end) + cache_end = np.minimum(self.cache_start + self.num_cache_layer, num_blocks) + hidden_states_before_cache = hidden_states.clone() + encoder_hidden_states_before_cache = encoder_hidden_states.clone() + if t_idx % self.cache_interval == (self.cache_start_steps % self.cache_interval): + hidden_states, encoder_hidden_states = self._forward_blocks_range( + hidden_states, + encoder_hidden_states, + emb, + self.cache_start, + cache_end + ) + self.delta_cache = hidden_states - hidden_states_before_cache + self.delta_encoder_cache = encoder_hidden_states - encoder_hidden_states_before_cache + else: + hidden_states = hidden_states_before_cache + self.delta_cache + encoder_hidden_states = encoder_hidden_states_before_cache + self.delta_encoder_cache + # infer [cache_end, num_blocks) + hidden_states, encoder_hidden_states = self._forward_blocks_range( + hidden_states, + encoder_hidden_states, + emb, + cache_end, + num_blocks + ) + + return hidden_states, encoder_hidden_states + + def load_weights(self, state_dict, shard=False): + with torch.no_grad(): + if not shard: + self.load_state_dict(state_dict) + return {} + else: + weights = state_dict + + for i in range(self.num_layers): + if i != 26: + q_weight = weights.pop(f"transformer_blocks.{i}.attn1.to_q.weight", None) + q_bias = weights.pop(f"transformer_blocks.{i}.attn1.to_q.bias", None) + k_weight = weights.pop(f"transformer_blocks.{i}.attn1.to_k.weight", None) + k_bias = weights.pop(f"transformer_blocks.{i}.attn1.to_k.bias", None) + v_weight = 
weights.pop(f"transformer_blocks.{i}.attn1.to_v.weight", None) + v_bias = weights.pop(f"transformer_blocks.{i}.attn1.to_v.bias", None) + + # query, key, value的weight和bias权重存在同一个文件中,不会分开存储。 + if q_weight is not None and k_weight is not None and v_weight is not None: + qkv_weight = torch.cat([q_weight, k_weight, v_weight], dim=0).transpose(0, 1).contiguous() + qkv_bias = torch.cat([q_bias, k_bias, v_bias], dim=0).contiguous() + weights[f"transformer_blocks.{i}.attn1.to_qkv.weight"] = qkv_weight + weights[f"transformer_blocks.{i}.attn1.to_qkv.bias"] = qkv_bias + else: + if self.q_weight_cache is None: + self.q_weight_cache = weights.pop(f"transformer_blocks.{i}.attn1.to_q.weight", None) + if self.q_bias_cache is None: + self.q_bias_cache = weights.pop(f"transformer_blocks.{i}.attn1.to_q.bias", None) + if self.k_weight_cache is None: + self.k_weight_cache = weights.pop(f"transformer_blocks.{i}.attn1.to_k.weight", None) + if self.k_bias_cache is None: + self.k_bias_cache = weights.pop(f"transformer_blocks.{i}.attn1.to_k.bias", None) + if self.v_weight_cache is None: + self.v_weight_cache = weights.pop(f"transformer_blocks.{i}.attn1.to_v.weight", None) + if self.v_bias_cache is None: + self.v_bias_cache = weights.pop(f"transformer_blocks.{i}.attn1.to_v.bias", None) + + qk_weight_cache = self.q_weight_cache is not None and self.k_weight_cache is not None + if qk_weight_cache and self.v_weight_cache is not None: + qkv_weight = torch.cat( + [self.q_weight_cache, self.k_weight_cache, self.v_weight_cache], + dim=0 + ).transpose(0, 1).contiguous() + qkv_bias = torch.cat([self.q_bias_cache, self.k_bias_cache, self.v_bias_cache], dim=0).contiguous() + weights[f"transformer_blocks.26.attn1.to_qkv.weight"] = qkv_weight + weights[f"transformer_blocks.26.attn1.to_qkv.bias"] = qkv_bias + + self.load_state_dict(weights, strict=False, assign=True) + return weights.keys() diff --git a/MindIE/MindIE-Torch/built-in/foundation/cogview3/cogview3plus/pipeline/__init__.py b/MindIE/MindIE-Torch/built-in/foundation/cogview3/cogview3plus/pipeline/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..626e0d588b9be2cd200013db7680e73655452afe --- /dev/null +++ b/MindIE/MindIE-Torch/built-in/foundation/cogview3/cogview3plus/pipeline/__init__.py @@ -0,0 +1 @@ +from .pipeline_cogview3plus import CogView3PlusPipeline, DiffusionPipeline \ No newline at end of file diff --git a/MindIE/MindIE-Torch/built-in/foundation/cogview3/cogview3plus/pipeline/pipeline_cogview3plus.py b/MindIE/MindIE-Torch/built-in/foundation/cogview3/cogview3plus/pipeline/pipeline_cogview3plus.py new file mode 100644 index 0000000000000000000000000000000000000000..fe2bd5cfcd33a2a7e99cdf4f79b1130dc86e5cf5 --- /dev/null +++ b/MindIE/MindIE-Torch/built-in/foundation/cogview3/cogview3plus/pipeline/pipeline_cogview3plus.py @@ -0,0 +1,339 @@ +# Copyright 2024 The CogVideoX team, Tsinghua University & ZhipuAI and The HuggingFace Team. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import inspect +from typing import List, Optional, Tuple, Union + +import torch +from transformers import T5EncoderModel, T5Tokenizer + +from diffusers.image_processor import VaeImageProcessor +from diffusers.pipelines.pipeline_utils import DiffusionPipeline +from diffusers.utils import logging +from diffusers.utils.torch_utils import randn_tensor +from diffusers import AutoencoderKL + +from ..models import CogView3PlusTransformer2DModel +from ..schedulers import CogVideoXDDIMScheduler +from .pipeline_output import CogView3PipelineOutput + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +def retrieve_timesteps( + scheduler, + num_inference_steps: Optional[int] = None, + device: Optional[Union[str, torch.device]] = None, + timesteps: Optional[List[int]] = None, + sigmas: Optional[List[float]] = None, + **kwargs, +): + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." 
+ ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +class CogView3PlusPipeline(DiffusionPipeline): + _optional_components = [] + model_cpu_offload_seq = "text_encoder->transformer->vae" + + _callback_tensor_inputs = [ + "latents", + "prompt_embeds", + "negative_prompt_embeds", + ] + + def __init__( + self, + tokenizer: T5Tokenizer, + text_encoder: T5EncoderModel, + vae: AutoencoderKL, + transformer: CogView3PlusTransformer2DModel, + scheduler: CogVideoXDDIMScheduler, + ): + super().__init__() + + self.register_modules( + tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, transformer=transformer, scheduler=scheduler + ) + self.vae_scale_factor = ( + 2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8 + ) + + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) + + def _get_t5_prompt_embeds( + self, + prompt: Union[str, List[str]] = None, + num_images_per_prompt: int = 1, + max_sequence_length: int = 226, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ): + device = device or self._execution_device + dtype = dtype or self.text_encoder.dtype + + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=max_sequence_length, + truncation=True, + add_special_tokens=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): + removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because `max_sequence_length` is set to " + f" {max_sequence_length} tokens: {removed_text}" + ) + + prompt_embeds = self.text_encoder(text_input_ids.to(device))[0] + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + + # duplicate text embeddings for each generation per prompt, using mps friendly method + _, seq_len, _ = prompt_embeds.shape + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + return prompt_embeds + + def encode_prompt( + self, + prompt: Union[str, List[str]], + num_images_per_prompt: int = 1, + max_sequence_length: int = 224, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ): + device = device or self._execution_device + + prompt = [prompt] if isinstance(prompt, str) else prompt + prompt_embeds = self._get_t5_prompt_embeds( + prompt=prompt, + num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, + device=device, + dtype=dtype, + ) + negative_prompt_embeds = prompt_embeds.new_zeros(prompt_embeds.shape) + + return prompt_embeds, negative_prompt_embeds + + def prepare_latents(self, batch_size, num_channels_latents, image_size, dtype, device): + height = image_size[0] + width = image_size[1] + shape = ( + batch_size, + num_channels_latents, + int(height) // self.vae_scale_factor, + int(width) // self.vae_scale_factor, + ) + latents 
= randn_tensor(shape, device=device, dtype=dtype) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents + + def prepare_extra_step_kwargs(self, generator, eta): + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + def check_inputs( + self, + prompt, + height, + width, + ): + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def do_classifier_free_guidance(self): + return self._guidance_scale > 1 + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def interrupt(self): + return self._interrupt + + @torch.no_grad() + def __call__( + self, + prompt: Optional[Union[str, List[str]]] = None, + image_size: Tuple[int, int] = None, + num_inference_steps: int = 50, + guidance_scale: float = 5.0, + num_images_per_prompt: int = 1, + ) -> Union[CogView3PipelineOutput, Tuple]: + if image_size is None: + height = self.transformer.config.sample_size * self.vae_scale_factor + width = self.transformer.config.sample_size * self.vae_scale_factor + else: + height = image_size[0] + width = image_size[1] + + original_size = (height, width) + target_size = (height, width) + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + height, + width, + ) + self._guidance_scale = guidance_scale + self._interrupt = False + + # 2. Default call parameters + if isinstance(prompt, str): + batch_size = 1 + else: + batch_size = len(prompt) + + device = self._execution_device + + # 3. Encode input prompt + prompt_embeds, negative_prompt_embeds = self.encode_prompt( + prompt, + num_images_per_prompt=num_images_per_prompt, + max_sequence_length=224, + device=device, + ) + if self.do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) + + # 4. Prepare timesteps + timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device) + self._num_timesteps = len(timesteps) + + # 5. Prepare latents. + latent_channels = self.transformer.config.in_channels + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + latent_channels, + (height, width), + prompt_embeds.dtype, + device, + ) + + extra_step_kwargs = self.prepare_extra_step_kwargs(None, 0.0) + + # 7. 
Prepare additional timestep conditions + original_size = torch.tensor([original_size], dtype=prompt_embeds.dtype) + target_size = torch.tensor([target_size], dtype=prompt_embeds.dtype) + crops_coords_top_left = torch.tensor([(0, 0)], dtype=prompt_embeds.dtype) + + if self.do_classifier_free_guidance: + original_size = torch.cat([original_size, original_size]) + target_size = torch.cat([target_size, target_size]) + crops_coords_top_left = torch.cat([crops_coords_top_left, crops_coords_top_left]) + + original_size = original_size.to(device).repeat(batch_size * num_images_per_prompt, 1) + target_size = target_size.to(device).repeat(batch_size * num_images_per_prompt, 1) + crops_coords_top_left = crops_coords_top_left.to(device).repeat(batch_size * num_images_per_prompt, 1) + + # 8. Denoising loop + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + + with self.progress_bar(total=num_inference_steps) as progress_bar: + # for DPM-solver++ + old_pred_original_sample = None + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timestep = t.expand(latent_model_input.shape[0]) + + # predict noise model_output + noise_pred = self.transformer( + states=(latent_model_input, prompt_embeds, i), + timestep=timestep, + original_size=original_size, + target_size=target_size, + crop_coords=crops_coords_top_left, + )[0] + noise_pred = noise_pred.float() + + # perform guidance + if self.do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] + latents = latents.to(prompt_embeds.dtype) + + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False, generator=None)[0] + image = self.image_processor.postprocess(image, output_type="pil") + + # Offload all models + self.maybe_free_model_hooks() + + return CogView3PipelineOutput(images=image) \ No newline at end of file diff --git a/MindIE/MindIE-Torch/built-in/foundation/cogview3/cogview3plus/pipeline/pipeline_output.py b/MindIE/MindIE-Torch/built-in/foundation/cogview3/cogview3plus/pipeline/pipeline_output.py new file mode 100644 index 0000000000000000000000000000000000000000..11f8976f0e5ab01333d91c8d452ae25f2e798eae --- /dev/null +++ b/MindIE/MindIE-Torch/built-in/foundation/cogview3/cogview3plus/pipeline/pipeline_output.py @@ -0,0 +1,21 @@ +from dataclasses import dataclass +from typing import List, Union + +import numpy as np +import PIL.Image + +from diffusers.utils import BaseOutput + + +@dataclass +class CogView3PipelineOutput(BaseOutput): + """ + Output class for CogView3 pipelines. + + Args: + images (`List[PIL.Image.Image]` or `np.ndarray`) + List of denoised PIL images of length `batch_size` or numpy array of shape `(batch_size, height, width, + num_channels)`. PIL images or numpy array present the denoised images of the diffusion pipeline. 
+ """ + + images: Union[List[PIL.Image.Image], np.ndarray] \ No newline at end of file diff --git a/MindIE/MindIE-Torch/built-in/foundation/cogview3/cogview3plus/schedulers/__init__.py b/MindIE/MindIE-Torch/built-in/foundation/cogview3/cogview3plus/schedulers/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..7a8f559a28aee98d7c3724b0b674f7f8e1230a63 --- /dev/null +++ b/MindIE/MindIE-Torch/built-in/foundation/cogview3/cogview3plus/schedulers/__init__.py @@ -0,0 +1,2 @@ +from .scheduling_ddim_cogvideox import CogVideoXDDIMScheduler +from .scheduling_utils import SchedulerMixin \ No newline at end of file diff --git a/MindIE/MindIE-Torch/built-in/foundation/cogview3/cogview3plus/schedulers/scheduling_ddim_cogvideox.py b/MindIE/MindIE-Torch/built-in/foundation/cogview3/cogview3plus/schedulers/scheduling_ddim_cogvideox.py new file mode 100644 index 0000000000000000000000000000000000000000..b3f6ce229b251658a2efaf4ee113942898735017 --- /dev/null +++ b/MindIE/MindIE-Torch/built-in/foundation/cogview3/cogview3plus/schedulers/scheduling_ddim_cogvideox.py @@ -0,0 +1,276 @@ +# Copyright 2024 The CogVideoX team, Tsinghua University & ZhipuAI and The HuggingFace Team. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +from dataclasses import dataclass +from typing import List, Optional, Tuple, Union + +import numpy as np +import torch + +from diffusers.configuration_utils import ConfigMixin, register_to_config +from diffusers.utils import BaseOutput +from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin + + +@dataclass +class DDIMSchedulerOutput(BaseOutput): + prev_sample: torch.Tensor + pred_original_sample: Optional[torch.Tensor] = None + + +def betas_for_alpha_bar( + num_diffusion_timesteps, + max_beta=0.999, + alpha_transform_type="cosine", +): + if alpha_transform_type == "cosine": + + def alpha_bar_fn(t): + return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2 + + elif alpha_transform_type == "exp": + + def alpha_bar_fn(t): + return math.exp(t * -12.0) + + else: + raise ValueError(f"Unsupported alpha_transform_type: {alpha_transform_type}") + + betas = [] + for i in range(num_diffusion_timesteps): + t1 = i / num_diffusion_timesteps + t2 = (i + 1) / num_diffusion_timesteps + betas.append(min(1 - alpha_bar_fn(t2) / alpha_bar_fn(t1), max_beta)) + return torch.tensor(betas, dtype=torch.float32) + + +def rescale_zero_terminal_snr(alphas_cumprod): + alphas_bar_sqrt = alphas_cumprod.sqrt() + + # Store old values. + alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone() + alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone() + + # Shift so the last timestep is zero. + alphas_bar_sqrt -= alphas_bar_sqrt_T + + # Scale so the first timestep is back to the old value. 
+ alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T) + + # Convert alphas_bar_sqrt to betas + alphas_bar = alphas_bar_sqrt**2 # Revert sqrt + + return alphas_bar + + +class CogVideoXDDIMScheduler(SchedulerMixin, ConfigMixin): + _compatibles = [e.name for e in KarrasDiffusionSchedulers] + order = 1 + + @register_to_config + def __init__( + self, + num_train_timesteps: int = 1000, + beta_start: float = 0.00085, + beta_end: float = 0.0120, + beta_schedule: str = "scaled_linear", + trained_betas: Optional[Union[np.ndarray, List[float]]] = None, + set_alpha_to_one: bool = True, + rescale_betas_zero_snr: bool = False, + snr_shift_scale: float = 3.0, + ): + if trained_betas is not None: + self.betas = torch.tensor(trained_betas, dtype=torch.float32) + elif beta_schedule == "linear": + self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32) + elif beta_schedule == "scaled_linear": + # this schedule is very specific to the latent diffusion model. + self.betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float64) ** 2 + elif beta_schedule == "squaredcos_cap_v2": + # Glide cosine schedule + self.betas = betas_for_alpha_bar(num_train_timesteps) + else: + raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}") + + self.alphas = 1.0 - self.betas + self.alphas_cumprod = torch.cumprod(self.alphas, dim=0) + + # Modify: SNR shift following SD3 + self.alphas_cumprod = self.alphas_cumprod / (snr_shift_scale + (1 - snr_shift_scale) * self.alphas_cumprod) + + # Rescale for zero SNR + if rescale_betas_zero_snr: + self.alphas_cumprod = rescale_zero_terminal_snr(self.alphas_cumprod) + + self.final_alpha_cumprod = torch.tensor(1.0) if set_alpha_to_one else self.alphas_cumprod[0] + + # standard deviation of the initial noise distribution + self.init_noise_sigma = 1.0 + + # setable values + self.num_inference_steps = None + self.timesteps = torch.from_numpy(np.arange(0, num_train_timesteps)[::-1].copy().astype(np.int64)) + + def _get_variance(self, timestep, prev_timestep): + alpha_prod_t = self.alphas_cumprod[timestep] + alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod + beta_prod_t = 1 - alpha_prod_t + beta_prod_t_prev = 1 - alpha_prod_t_prev + + variance = (beta_prod_t_prev / beta_prod_t) * (1 - alpha_prod_t / alpha_prod_t_prev) + + return variance + + def scale_model_input(self, sample: torch.Tensor, timestep: Optional[int] = None) -> torch.Tensor: + return sample + + def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None): + """ + Sets the discrete timesteps used for the diffusion chain (to be run before inference). + + Args: + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. + """ + + if num_inference_steps > self.config.num_train_timesteps: + raise ValueError( + f"`num_inference_steps`: {num_inference_steps} cannot be larger than `self.config.train_timesteps`:" + f" {self.config.num_train_timesteps} as the unet model trained with this scheduler can only handle" + f" maximal {self.config.num_train_timesteps} timesteps." 
+ ) + + self.num_inference_steps = num_inference_steps + + if self.config.timestep_spacing == "linspace": + timesteps = ( + np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps) + .round()[::-1] + .copy() + .astype(np.int64) + ) + elif self.config.timestep_spacing == "leading": + step_ratio = self.config.num_train_timesteps // self.num_inference_steps + # creates integer timesteps by multiplying by ratio + # casting to int to avoid issues when num_inference_step is power of 3 + timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(np.int64) + timesteps += self.config.steps_offset + elif self.config.timestep_spacing == "trailing": + step_ratio = self.config.num_train_timesteps / self.num_inference_steps + # creates integer timesteps by multiplying by ratio + # casting to int to avoid issues when num_inference_step is power of 3 + timesteps = np.round(np.arange(self.config.num_train_timesteps, 0, -step_ratio)).astype(np.int64) + timesteps -= 1 + else: + raise ValueError( + f"{self.config.timestep_spacing} is not supported. Please make sure to choose one of 'leading' or 'trailing'." + ) + + self.timesteps = torch.from_numpy(timesteps).to(device) + + def step( + self, + model_output: torch.Tensor, + timestep: int, + sample: torch.Tensor, + return_dict: bool = True, + ) -> Union[DDIMSchedulerOutput, Tuple]: + if self.num_inference_steps is None: + raise ValueError( + "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler" + ) + + prev_timestep = timestep - self.config.num_train_timesteps // self.num_inference_steps + + # 2. compute alphas, betas + alpha_prod_t = self.alphas_cumprod[timestep] + alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod + + beta_prod_t = 1 - alpha_prod_t + + # 3. 
compute predicted original sample from predicted noise also called + if self.config.prediction_type == "epsilon": + pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5) + elif self.config.prediction_type == "sample": + pred_original_sample = model_output + elif self.config.prediction_type == "v_prediction": + pred_original_sample = (alpha_prod_t**0.5) * sample - (beta_prod_t**0.5) * model_output + else: + raise ValueError( + f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or" + " `v_prediction`" + ) + + a_t = ((1 - alpha_prod_t_prev) / (1 - alpha_prod_t)) ** 0.5 + b_t = alpha_prod_t_prev**0.5 - alpha_prod_t**0.5 * a_t + + prev_sample = a_t * sample + b_t * pred_original_sample + + if not return_dict: + return ( + prev_sample, + pred_original_sample, + ) + + return DDIMSchedulerOutput(prev_sample=prev_sample, pred_original_sample=pred_original_sample) + + def add_noise( + self, + original_samples: torch.Tensor, + noise: torch.Tensor, + timesteps: torch.IntTensor, + ) -> torch.Tensor: + # Make sure alphas_cumprod and timestep have same device and dtype as original_samples + # Move the self.alphas_cumprod to device to avoid redundant CPU to GPU data movement + # for the subsequent add_noise calls + self.alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device) + alphas_cumprod = self.alphas_cumprod.to(dtype=original_samples.dtype) + timesteps = timesteps.to(original_samples.device) + + sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5 + sqrt_alpha_prod = sqrt_alpha_prod.flatten() + while len(sqrt_alpha_prod.shape) < len(original_samples.shape): + sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1) + + sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5 + sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten() + while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape): + sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1) + + noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise + return noisy_samples + + def get_velocity(self, sample: torch.Tensor, noise: torch.Tensor, timesteps: torch.IntTensor) -> torch.Tensor: + # Make sure alphas_cumprod and timestep have same device and dtype as sample + self.alphas_cumprod = self.alphas_cumprod.to(device=sample.device) + alphas_cumprod = self.alphas_cumprod.to(dtype=sample.dtype) + timesteps = timesteps.to(sample.device) + + sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5 + sqrt_alpha_prod = sqrt_alpha_prod.flatten() + while len(sqrt_alpha_prod.shape) < len(sample.shape): + sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1) + + sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5 + sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten() + while len(sqrt_one_minus_alpha_prod.shape) < len(sample.shape): + sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1) + + velocity = sqrt_alpha_prod * noise - sqrt_one_minus_alpha_prod * sample + return velocity + + def __len__(self): + return self.config.num_train_timesteps \ No newline at end of file diff --git a/MindIE/MindIE-Torch/built-in/foundation/cogview3/cogview3plus/schedulers/scheduling_utils.py b/MindIE/MindIE-Torch/built-in/foundation/cogview3/cogview3plus/schedulers/scheduling_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..d854366c7791c96c5e890b1aae0d4800ac657e1f --- /dev/null +++ 
+++ b/MindIE/MindIE-Torch/built-in/foundation/cogview3/cogview3plus/schedulers/scheduling_utils.py
@@ -0,0 +1,113 @@
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import importlib
+import os
+from dataclasses import dataclass
+from enum import Enum
+from typing import Optional, Union
+
+import torch
+from huggingface_hub.utils import validate_hf_hub_args
+
+from diffusers.utils import BaseOutput, PushToHubMixin
+
+
+SCHEDULER_CONFIG_NAME = "scheduler_config.json"
+
+
+class KarrasDiffusionSchedulers(Enum):
+    DDIMScheduler = 1
+    DDPMScheduler = 2
+    PNDMScheduler = 3
+    LMSDiscreteScheduler = 4
+    EulerDiscreteScheduler = 5
+    HeunDiscreteScheduler = 6
+    EulerAncestralDiscreteScheduler = 7
+    DPMSolverMultistepScheduler = 8
+    DPMSolverSinglestepScheduler = 9
+    KDPM2DiscreteScheduler = 10
+    KDPM2AncestralDiscreteScheduler = 11
+    DEISMultistepScheduler = 12
+    UniPCMultistepScheduler = 13
+    DPMSolverSDEScheduler = 14
+    EDMEulerScheduler = 15
+
+
+AysSchedules = {
+    "StableDiffusionTimesteps": [999, 850, 736, 645, 545, 455, 343, 233, 124, 24],
+    "StableDiffusionSigmas": [14.615, 6.475, 3.861, 2.697, 1.886, 1.396, 0.963, 0.652, 0.399, 0.152, 0.0],
+    "StableDiffusionXLTimesteps": [999, 845, 730, 587, 443, 310, 193, 116, 53, 13],
+    "StableDiffusionXLSigmas": [14.615, 6.315, 3.771, 2.181, 1.342, 0.862, 0.555, 0.380, 0.234, 0.113, 0.0],
+    "StableDiffusionVideoSigmas": [700.00, 54.5, 15.886, 7.977, 4.248, 1.789, 0.981, 0.403, 0.173, 0.034, 0.0],
+}
+
+
+@dataclass
+class SchedulerOutput(BaseOutput):
+    """
+    Base class for the output of a scheduler's `step` function.
+
+    Args:
+        prev_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images):
+            Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the
+            denoising loop.
+    """
+
+    prev_sample: torch.Tensor
+
+
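+# Lightweight counterpart of diffusers' SchedulerMixin: builds a scheduler instance from
+# its scheduler_config.json via `from_pretrained` / `from_config` and exposes the set of
+# compatible scheduler classes through the `compatibles` property.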
+ """ + + prev_sample: torch.Tensor + + +class SchedulerMixin(PushToHubMixin): + + config_name = SCHEDULER_CONFIG_NAME + _compatibles = [] + has_compatibles = True + + @classmethod + @validate_hf_hub_args + def from_pretrained( + cls, + pretrained_model_name_or_path: Optional[Union[str, os.PathLike]] = None, + subfolder: Optional[str] = None, + return_unused_kwargs=False, + **kwargs, + ): + + config, kwargs, _ = cls.load_config( + pretrained_model_name_or_path=pretrained_model_name_or_path, + subfolder=subfolder, + return_unused_kwargs=True, + return_commit_hash=True, + **kwargs, + ) + return cls.from_config(config, return_unused_kwargs=return_unused_kwargs, **kwargs) + + @property + def compatibles(self): + """ + Returns all schedulers that are compatible with this scheduler + + Returns: + `List[SchedulerMixin]`: List of compatible schedulers + """ + return self._get_compatibles() + + @classmethod + def _get_compatibles(cls): + compatible_classes_str = list(set([cls.__name__] + cls._compatibles)) + diffusers_library = importlib.import_module(__name__.split(".")[0]) + compatible_classes = [ + getattr(diffusers_library, c) for c in compatible_classes_str if hasattr(diffusers_library, c) + ] + return compatible_classes \ No newline at end of file diff --git a/MindIE/MindIE-Torch/built-in/foundation/cogview3/cogview3plus/vae/__init__.py b/MindIE/MindIE-Torch/built-in/foundation/cogview3/cogview3plus/vae/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/MindIE/MindIE-Torch/built-in/foundation/cogview3/inference_cogview3plus.py b/MindIE/MindIE-Torch/built-in/foundation/cogview3/inference_cogview3plus.py new file mode 100644 index 0000000000000000000000000000000000000000..c3bb1f2ebbf59fe976dab8533c293a6d2c76afe9 --- /dev/null +++ b/MindIE/MindIE-Torch/built-in/foundation/cogview3/inference_cogview3plus.py @@ -0,0 +1,105 @@ +#!/usr/bin/env python +# coding=utf-8 +# Copyright 2024 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +import logging +import time + +import torch + +from cogview3plus import CogView3PlusPipeline + +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + + +def parse_arguments(): + parser = argparse.ArgumentParser(description="Generate an image using the CogView3-Plus-3B model.") + + # Define arguments for prompt, model path, etc. + parser.add_argument( + "--prompt", + type=list, + default=[ + "A vibrant cherry red sports car sits proudly under the gleaming sun, \ + its polished exterior smooth and flawless, casting a mirror-like reflection. \ + The car features a low, aerodynamic body, angular headlights that gaze forward like predatory eyes, \ + and a set of black, high-gloss racing rims that contrast starkly with the red. \ + A subtle hint of chrome embellishes the grille and exhaust, \ + while the tinted windows suggest a luxurious and private interior. 
+    parser.add_argument(
+        "--prompt",
+        type=str,
+        nargs="+",
+        default=[
+            "A vibrant cherry red sports car sits proudly under the gleaming sun, \
+            its polished exterior smooth and flawless, casting a mirror-like reflection. \
+            The car features a low, aerodynamic body, angular headlights that gaze forward like predatory eyes, \
+            and a set of black, high-gloss racing rims that contrast starkly with the red. \
+            A subtle hint of chrome embellishes the grille and exhaust, \
+            while the tinted windows suggest a luxurious and private interior. \
+            The scene conveys a sense of speed and elegance, \
+            the car appearing as if it's about to burst into a sprint along a coastal road, \
+            with the ocean's azure waves crashing in the background."
+        ],
+        help="The text description for generating the image."
+    )
+    parser.add_argument(
+        "--model_path", type=str, default="/data/CogView3B", help="Path to the pre-trained model."
+    )
+    parser.add_argument(
+        "--guidance_scale", type=float, default=7.0, help="The guidance scale for classifier-free guidance."
+    )
+    parser.add_argument(
+        "--num_images_per_prompt", type=int, default=1, help="Number of images to generate per prompt."
+    )
+    parser.add_argument("--num_inference_steps", type=int, default=50, help="Number of denoising steps for inference.")
+    parser.add_argument("--width", type=int, default=1024, help="Width of the generated image.")
+    parser.add_argument("--height", type=int, default=1024, help="Height of the generated image.")
+    parser.add_argument("--output_path", type=str, default="cogview3.png", help="Path to save the generated image.")
+    parser.add_argument("--dtype", type=str, default="bf16", help="bf16 or fp16")
+    parser.add_argument("--device_id", type=int, default=7, help="NPU device id")
+
+    return parser.parse_args()
+
+
+def infer(args):
+    torch.npu.set_device(args.device_id)
+    dtype = torch.bfloat16 if args.dtype == "bf16" else torch.float16
+
+    # Load the pre-trained model with the specified precision
+    pipe = CogView3PlusPipeline.from_pretrained(args.model_path, torch_dtype=dtype).to("npu")
+
+    use_time = 0
+    loops = 5
+    for i in range(loops):
+        start_time = time.time()
+        # Generate the image based on the prompt
+        image = pipe(
+            prompt=args.prompt[0],
+            guidance_scale=args.guidance_scale,
+            num_images_per_prompt=args.num_images_per_prompt,
+            num_inference_steps=args.num_inference_steps,
+            image_size=(args.height, args.width),
+        ).images[0]
+
+        # The first two iterations are treated as warm-up and excluded from the average
+        if i >= 2:
+            use_time += time.time() - start_time
+        logger.info("current_time is %.3f", time.time() - start_time)
+
+        torch.npu.empty_cache()
+
+    logger.info("use_time is %.3f", use_time / (loops - 2))
+
+    # Save the generated image to the local file system
+    image.save(args.output_path)
+
+    print(f"Image saved to {args.output_path}")
+
+
+if __name__ == "__main__":
+    inference_args = parse_arguments()
+    infer(inference_args)
+
diff --git a/MindIE/MindIE-Torch/built-in/foundation/cogview3/requirents.txt b/MindIE/MindIE-Torch/built-in/foundation/cogview3/requirents.txt
new file mode 100644
index 0000000000000000000000000000000000000000..1600434700000cd99216b7d6179326b0d54380e0
--- /dev/null
+++ b/MindIE/MindIE-Torch/built-in/foundation/cogview3/requirents.txt
@@ -0,0 +1,8 @@
+deepspeed==0.16.1
+transformers==4.47.1
+gradio==5.9.1
+accelerate==1.0.1
+diffusers==0.31.0
+sentencepiece==0.2.0
+torch==2.4.0
+openai==1.58.1
\ No newline at end of file