From 8b0018467d544d5869b35c2b61c43de35d334771 Mon Sep 17 00:00:00 2001 From: Shitong Li Date: Mon, 22 Jul 2024 17:29:42 +0800 Subject: [PATCH 1/2] Opensora framework migration to Megatron (Mindspeed): 1. Add pretrain_opensora.py\pretrain_opensora.sh training scripts 2. MindSpeed PP\VPP adaptation MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../opensora/models/stdit/__init__.py | 1 + .../opensora/models/stdit/stdit.py | 4 +- .../opensora/models/stdit/stdit_mindspeed.py | 509 ++++++++++++++++++ .../OpenSora1.0/opensora/models/vae/vae.py | 31 +- .../opensora/schedulers/__init__.py | 2 +- .../opensora/schedulers/iddpm/__init__.py | 77 +++ .../iddpm/gaussian_diffusion_mindspeed.py | 217 ++++++++ .../schedulers/iddpm/respace_mindspeed.py | 50 ++ 8 files changed, 885 insertions(+), 6 deletions(-) create mode 100644 PyTorch/built-in/mlm/OpenSora1.0/opensora/models/stdit/stdit_mindspeed.py create mode 100644 PyTorch/built-in/mlm/OpenSora1.0/opensora/schedulers/iddpm/gaussian_diffusion_mindspeed.py create mode 100644 PyTorch/built-in/mlm/OpenSora1.0/opensora/schedulers/iddpm/respace_mindspeed.py diff --git a/PyTorch/built-in/mlm/OpenSora1.0/opensora/models/stdit/__init__.py b/PyTorch/built-in/mlm/OpenSora1.0/opensora/models/stdit/__init__.py index 5ca2cc91f8..712eba543f 100644 --- a/PyTorch/built-in/mlm/OpenSora1.0/opensora/models/stdit/__init__.py +++ b/PyTorch/built-in/mlm/OpenSora1.0/opensora/models/stdit/__init__.py @@ -1 +1,2 @@ from .stdit import STDiT +from .stdit_mindspeed import STDiT_mindspeed diff --git a/PyTorch/built-in/mlm/OpenSora1.0/opensora/models/stdit/stdit.py b/PyTorch/built-in/mlm/OpenSora1.0/opensora/models/stdit/stdit.py index 84cf685077..bf10b2ff46 100644 --- a/PyTorch/built-in/mlm/OpenSora1.0/opensora/models/stdit/stdit.py +++ b/PyTorch/built-in/mlm/OpenSora1.0/opensora/models/stdit/stdit.py @@ -29,7 +29,6 @@ from opensora.registry import MODELS from opensora.utils.ckpt_utils import load_checkpoint from opensora.utils.config_utils import parse_configs -cfg = parse_configs(training=True) class STDiTBlock(nn.Module): def __init__( @@ -48,6 +47,7 @@ class STDiTBlock(nn.Module): self.hidden_size = hidden_size self.enable_flashattn = enable_flashattn self._enable_sequence_parallelism = enable_sequence_parallelism + cfg = parse_configs(training=True) if enable_sequence_parallelism and cfg.context_parallel_algo != "dsp_cp_algo": self.attn_cls = AttentionWithCp self.mha_cls = SeqParallelMultiHeadCrossAttention @@ -102,6 +102,7 @@ class STDiTBlock(nn.Module): x = x + self.drop_path(gate_msa * x_s) # temporal to spatital switch in dsp + cfg = parse_configs(training=True) if self._enable_sequence_parallelism and cfg.context_parallel_algo == "dsp_cp_algo": x = rearrange(x, "B (T S) C -> B T S C", T=self.d_t, S=self.d_s) x = all_to_all(x, get_sequence_parallel_group(), scatter_dim=2, gather_dim=1) @@ -243,6 +244,7 @@ class STDiT(nn.Module): x (torch.Tensor): output latent representation; of shape [B, C, T, H, W] """ + cfg = parse_configs(training=True) x = x.to(self.dtype) timestep = timestep.to(self.dtype) y = y.to(self.dtype) diff --git a/PyTorch/built-in/mlm/OpenSora1.0/opensora/models/stdit/stdit_mindspeed.py b/PyTorch/built-in/mlm/OpenSora1.0/opensora/models/stdit/stdit_mindspeed.py new file mode 100644 index 0000000000..e6b407d20a --- /dev/null +++ 
b/PyTorch/built-in/mlm/OpenSora1.0/opensora/models/stdit/stdit_mindspeed.py @@ -0,0 +1,509 @@ +import math + +import numpy as np +import torch +import torch.distributed as dist +import torch.nn as nn +from einops import rearrange +from timm.models.layers import DropPath +from timm.models.vision_transformer import Mlp + +from megatron.core import mpu, tensor_parallel +from megatron.training import get_args +from megatron.legacy.model.module import MegatronModule +from megatron.training.arguments import core_transformer_config_from_args +from megatron.legacy.model.transformer import ParallelAttention +from megatron.legacy.model.enums import AttnMaskType, AttnType + +from opensora.acceleration.checkpoint import auto_grad_checkpoint +from opensora.acceleration.communications import (gather_forward_split_backward, split_forward_gather_backward, + all_to_all) +from opensora.models.layers.blocks_mindspeed import ( + CaptionEmbedder, + PatchEmbed3D, + ParallelMultiHeadCrossAttention, + T2IFinalLayer, + TimestepEmbedder, + approx_gelu, + get_1d_sincos_pos_embed, + get_2d_sincos_pos_embed, + get_layernorm, + t2i_modulate, +) +from opensora.registry import MODELS + + +class STDiTBlock(nn.Module): + def __init__( + self, + d_s=None, + d_t=None, + mlp_ratio=4.0, + drop_path=0.0, + enable_layernorm_kernel=False, + enable_sequence_parallelism=False, + config=None, + layer_number=0, + ): + super().__init__() + self.config = config + self.layer_number = layer_number + self._enable_sequence_parallelism = enable_sequence_parallelism + self.attn_cls = ParallelAttention + self.mha_cls = ParallelMultiHeadCrossAttention + + self.norm1 = get_layernorm(config.hidden_size, eps=1e-6, affine=False, use_kernel=enable_layernorm_kernel) + + self.attn = self.attn_cls( + config, + layer_number, + attention_type=AttnType.self_attn, + attn_mask_type=AttnMaskType.padding) + self.cross_attn = self.mha_cls(config=config) + self.norm2 = get_layernorm(config.hidden_size, eps=1e-6, affine=False, use_kernel=enable_layernorm_kernel) + self.mlp = Mlp(in_features=config.hidden_size, hidden_features=int(config.hidden_size * mlp_ratio), + act_layer=approx_gelu, drop=0 + ) + self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() + self.scale_shift_table = nn.Parameter(torch.randn(6, config.hidden_size) / (config.hidden_size ** 0.5)) + self.sp_size = mpu.get_context_parallel_world_size() if mpu.get_context_parallel_world_size() > 0 else 1 + # temporal attention + self.d_s = d_s + assert d_t % self.sp_size == 0 + self.d_t = d_t // self.sp_size + + self.attn_temp = self.attn_cls( + config, + layer_number, + attention_type=AttnType.self_attn, + attn_mask_type=AttnMaskType.padding) + + def forward(self, x, y, t, mask=None, tpe=None): + B, N, C = x.shape + + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = ( + self.scale_shift_table[None] + t.reshape(B, 6, -1) + ).chunk(6, dim=1) + x_m = t2i_modulate(self.norm1(x), shift_msa, scale_msa) + + # spatial branch + x_s = rearrange(x_m, "B (T S) C -> S (B T) C", T=self.d_t, S=self.d_s) + x_s = self.attn(x_s, attention_mask=None)[0] + x_s = rearrange(x_s, "S (B T) C -> B (T S) C", T=self.d_t, S=self.d_s) + + x = x + self.drop_path(gate_msa * x_s) + + # temporal to spatital switch in dsp + if self._enable_sequence_parallelism and self.config.context_parallel_algo == "dsp_cp_algo": + x = rearrange(x, "B (T S) C -> B T S C", T=self.d_t, S=self.d_s) + x = all_to_all(x, mpu.get_context_parallel_group(), scatter_dim=2, gather_dim=1) + self.d_t = self.d_t * self.sp_size + 
self.d_s = self.d_s // self.sp_size + x = rearrange(x, "B T S C -> B (T S) C", T=self.d_t, S=self.d_s) + + # temporal branch + x_t = rearrange(x, "B (T S) C -> (B S) T C", T=self.d_t, S=self.d_s) + if tpe is not None: + x_t = x_t + tpe + x_t = rearrange(x_t, "(B S) T C -> T (B S) C", T=self.d_t, S=self.d_s) + x_t = self.attn_temp(x_t, attention_mask=None)[0] + x_t = rearrange(x_t, "T (B S) C -> B (T S) C", T=self.d_t, S=self.d_s) + x = x + self.drop_path(gate_msa * x_t) + + # spatital to temporal switch in dsp + if self._enable_sequence_parallelism and self.config.context_parallel_algo == "dsp_cp_algo": + x = rearrange(x, "B (T S) C -> B T S C", T=self.d_t, S=self.d_s) + x = all_to_all(x, mpu.get_context_parallel_group(), scatter_dim=1, gather_dim=2) + self.d_t = self.d_t // self.sp_size + self.d_s = self.d_s * self.sp_size + x = rearrange(x, "B T S C -> B (T S) C", T=self.d_t, S=self.d_s) + + # cross attn + x = x + self.cross_attn(x, y, mask) + + # mlp + x = x + self.drop_path(gate_mlp * self.mlp(t2i_modulate(self.norm2(x), shift_mlp, scale_mlp))) + + return x + + +@MODELS.register_module() +class STDiT_mindspeed(MegatronModule): + def __init__( + self, + input_size=(1, 32, 32), + in_channels=4, + patch_size=(1, 2, 2), + mlp_ratio=4.0, + class_dropout_prob=0.1, + pred_sigma=True, + drop_path=0.0, + no_temporal_pos_emb=False, + caption_channels=4096, + model_max_length=120, + dtype=torch.float32, + space_scale=1.0, + time_scale=1.0, + freeze=None, + enable_layernorm_kernel=False, + enable_sequence_parallelism=False, + pre_process=True, + post_process=True + ): + super().__init__(share_embeddings_and_output_weights=False) + args = get_args() + config = core_transformer_config_from_args(args) + self.config = config + self.pred_sigma = pred_sigma + self.in_channels = in_channels + self.out_channels = in_channels * 2 if pred_sigma else in_channels + self.hidden_size = config.hidden_size + self.patch_size = patch_size + self.input_size = input_size + num_patches = np.prod([input_size[i] // patch_size[i] for i in range(3)]) + self.num_patches = num_patches + self.num_temporal = input_size[0] // patch_size[0] + self.num_spatial = num_patches // self.num_temporal + self.dtype = dtype + self.no_temporal_pos_emb = no_temporal_pos_emb + self.depth = config.num_layers + self.mlp_ratio = mlp_ratio + self.enable_layernorm_kernel = enable_layernorm_kernel + self.space_scale = space_scale + self.time_scale = time_scale + self.pre_process = pre_process + self.post_process = post_process + self.enable_sequence_parallelism = enable_sequence_parallelism + self.sp_rank = mpu.get_context_parallel_rank() + + self.register_buffer("pos_embed", self.get_spatial_pos_embed()) + self.register_buffer("pos_embed_temporal", self.get_temporal_pos_embed()) + + is_first_stage = mpu.is_pipeline_first_stage() + is_last_stage = mpu.is_pipeline_last_stage() + self.x_embedder = PatchEmbed3D(patch_size, in_channels, self.hidden_size) if is_first_stage else None + self.t_embedder = TimestepEmbedder(self.hidden_size, config=config) if is_first_stage else None + self.t_block = ( + nn.Sequential( + nn.SiLU(), + tensor_parallel.ColumnParallelLinear( + self.hidden_size, + 6 * self.hidden_size, + config=config, + init_method=config.init_method, + bias=True, + gather_output=True) + ) if is_first_stage else None + ) + self.y_embedder = ( + CaptionEmbedder( + in_channels=caption_channels, + hidden_size=self.hidden_size, + uncond_prob=class_dropout_prob, + act_layer=approx_gelu, + token_num=model_max_length, + ) + ) if is_first_stage else 
None + + # Number of layers. + drop_path = [x.item() for x in torch.linspace(0, drop_path, self.depth)] + self.num_layers = self._get_num_layers() + + if config.virtual_pipeline_model_parallel_size is not None: + assert self.num_layers % config.virtual_pipeline_model_parallel_size == 0, \ + 'num_layers_per_stage must be divisible by ' \ + 'virtual_pipeline_model_parallel_size' + # Number of layers in each model chunk is the number of layers in the stage, + # divided by the number of model chunks in a stage. + self.num_layers = self.num_layers // config.virtual_pipeline_model_parallel_size + # With 8 layers, 2 stages, and 4 model chunks, we want an assignment of + # layers to stages like (each list is a model chunk): + # Stage 0: [0] [2] [4] [6] + # Stage 1: [1] [3] [5] [7] + # With 8 layers, 2 stages, and 2 virtual stages, we want an assignment of + # layers to stages like (each list is a model chunk): + # Stage 0: [0, 1] [4, 5] + # Stage 1: [2, 3] [6, 7] + offset = mpu.get_virtual_pipeline_model_parallel_rank() * ( + config.num_layers // config.virtual_pipeline_model_parallel_size) + \ + (mpu.get_pipeline_model_parallel_rank() * self.num_layers) + else: + # Each stage gets a contiguous set of layers. + offset = mpu.get_pipeline_model_parallel_rank() * self.num_layers + + self.blocks = torch.nn.ModuleList( + [self.build_layer(drop_path, i + offset, enable_sequence_parallelism) for i in range(self.num_layers)]) + + self.final_layer = T2IFinalLayer(self.hidden_size, np.prod(self.patch_size), self.out_channels, config=config) \ + if is_last_stage else None + + + if freeze is not None: + assert freeze in ["not_temporal", "text"] + if freeze == "not_temporal": + self.freeze_not_temporal() + elif freeze == "text": + self.freeze_text() + + + def _get_num_layers(self): + def _get_pipelinse_stage_split_index(pp_word_size): + ceil = math.ceil(self.depth / pp_word_size) + floor = ceil - 1 + for index in range(pp_word_size): + if ceil * (index + 1) + floor * (pp_word_size - (index + 1)) == self.depth: + return index + pp_word_size = mpu.get_pipeline_model_parallel_world_size() + pp_rank = mpu.get_pipeline_model_parallel_rank() + if get_args().virtual_pipeline_model_parallel_size is not None: + assert self.depth % pp_word_size == 0, \ + 'num_layers must be divisible by pipeline_model_parallel_world_size' + return self.depth // pp_word_size + else: + if self.depth % pp_word_size == 0: + return self.depth // pp_word_size + index = _get_pipelinse_stage_split_index(pp_word_size) + ceil = math.ceil(self.depth / pp_word_size) + if pp_rank <= index: + return ceil + else: + return ceil - 1 + + def set_input_tensor(self, input_tensor): + """Set input tensor to be used instead of forward()'s input. + + When doing pipeline parallelism the input from the previous + stage comes from communication, not from the input, so the + model's forward_step_func won't have it. 
This function is thus + used by internal code to bypass the input provided by the + forward_step_func""" + self.input_tensor = input_tensor + + def build_layer(self, drop_path, layer_number, enable_sequence_parallelism): + return STDiTBlock( + mlp_ratio=self.mlp_ratio, + drop_path=drop_path[layer_number], + enable_layernorm_kernel=self.enable_layernorm_kernel, + enable_sequence_parallelism=enable_sequence_parallelism, + d_t=self.num_temporal, + d_s=self.num_spatial, + config=self.config, + layer_number=layer_number, + ) + + def _get_layer(self, layer_number): + return self.layers[layer_number] + + def forward(self, x, timestep, y, x_0, noise, mask=None): + """ + Forward pass of STDiT. + Args: + x (torch.Tensor): latent representation of video; of shape [B, C, T, H, W] + timestep (torch.Tensor): diffusion time steps; of shape [B] + y (torch.Tensor): representation of prompts; of shape [B, 1, N_token, C] + mask (torch.Tensor): mask for selecting prompt tokens; of shape [B, N_token] + + Returns: + x (torch.Tensor): output latent representation; of shape [B, C, T, H, W] + """ + + vpp_rank = mpu.get_virtual_pipeline_model_parallel_rank() + if vpp_rank is None: + vpp_rank = 0 + pp_rank = mpu.get_pipeline_model_parallel_rank() + pp_size = mpu.get_pipeline_model_parallel_world_size() + + layer_index = pp_size * vpp_rank + pp_rank + + # set x from the output of last stage in PP + rank = torch.distributed.get_rank() + + if mpu.is_pipeline_first_stage(): + x_t = x.clone() + else: + x, x_t, y, timestep, t0, mask, x_0, noise, t = self.input_tensor + + timestep_clone = timestep.clone() + timestep = timestep.to(torch.int64) + mask = mask.to(torch.int64) + + if mpu.is_pipeline_first_stage(): + # embedding + # [B, N, C] + x = self.x_embedder(x) + x = rearrange(x, "B (T S) C -> B T S C", T=self.num_temporal, S=self.num_spatial) + x = x + self.pos_embed + x = rearrange(x, "B T S C -> B (T S) C") + + # shard over the sequence dim if sp is enabled + # [B, C] + t = self.t_embedder(timestep, dtype=x.dtype) + # [B, C] + t0, _ = self.t_block(t) + # [B, 1, N_token, C] + y = self.y_embedder(y, self.training) + + if mpu.get_context_parallel_world_size() > 1: + x = split_forward_gather_backward(x, mpu.get_context_parallel_group(), dim=1, grad_scale="down") + + mask_clone = mask.clone() + y_clone = y.clone() + + if mask is not None: + if mask.shape[0] != y.shape[0]: + mask = mask.repeat(y.shape[0] // mask.shape[0], 1) + mask = mask.squeeze(1).squeeze(1) + y = y.squeeze(1).masked_select(mask.unsqueeze(-1) != 0).view(1, -1, x.shape[-1]) + y_lens = mask.sum(dim=1).tolist() + else: + if mpu.is_pipeline_first_stage(): + y_lens = [y.shape[2]] * y.shape[0] + y = y.squeeze(1).view(1, -1, x.shape[-1]) + + for i, block in enumerate(self.blocks): + if mpu.is_pipeline_first_stage() and i == 0: + if mpu.get_context_parallel_world_size() > 1: + tpe = torch.chunk( + self.pos_embed_temporal, mpu.get_context_parallel_world_size(), dim=1 + )[self.sp_rank].contiguous() + else: + tpe = self.pos_embed_temporal + else: + tpe = None + + x = auto_grad_checkpoint(block, x, y, t0, y_lens, tpe) + + if mpu.get_context_parallel_world_size() > 1: + x = gather_forward_split_backward(x, mpu.get_context_parallel_group(), dim=1, grad_scale="up") + + if mpu.is_pipeline_last_stage(): + x = self.final_layer(x, t) + # [B, C_out, T, H, W] + x = self.unpatchify(x) + + # cast to float32 for better accuracy + x = x.to(torch.float32) + else: + x = x.to(self.dtype) + x_t = x_t.to(self.dtype) + y_clone = y_clone.to(self.dtype) + timestep_clone = 
timestep_clone.to(torch.float32) + t0 = t0.to(self.dtype) + mask_clone = mask_clone.to(torch.float32) + x_0 = x_0.to(self.dtype) + noise = noise.to(self.dtype) + t = t.to(self.dtype) + + if pp_size <= 1: + return x + + return x, x_t, y_clone, timestep_clone, t0, mask_clone, x_0, noise, t + + def unpatchify(self, x): + """ + Args: + x (torch.Tensor): of shape [B, N, C] + + Return: + x (torch.Tensor): of shape [B, C_out, T, H, W] + """ + + N_t, N_h, N_w = [self.input_size[i] // self.patch_size[i] for i in range(3)] + T_p, H_p, W_p = self.patch_size + x = rearrange( + x, + "B (N_t N_h N_w) (T_p H_p W_p C_out) -> B C_out (N_t T_p) (N_h H_p) (N_w W_p)", + N_t=N_t, + N_h=N_h, + N_w=N_w, + T_p=T_p, + H_p=H_p, + W_p=W_p, + C_out=self.out_channels, + ) + return x + + def unpatchify_old(self, x): + c = self.out_channels + t, h, w = [self.input_size[i] // self.patch_size[i] for i in range(3)] + pt, ph, pw = self.patch_size + + x = x.reshape(shape=(x.shape[0], t, h, w, pt, ph, pw, c)) + x = rearrange(x, "n t h w r p q c -> n c t r h p w q") + imgs = x.reshape(shape=(x.shape[0], c, t * pt, h * ph, w * pw)) + return imgs + + def get_spatial_pos_embed(self, grid_size=None): + if grid_size is None: + grid_size = self.input_size[1:] + pos_embed = get_2d_sincos_pos_embed( + self.hidden_size, + (grid_size[0] // self.patch_size[1], grid_size[1] // self.patch_size[2]), + scale=self.space_scale, + ) + pos_embed = torch.from_numpy(pos_embed).float().unsqueeze(0).requires_grad_(False) + return pos_embed + + def get_temporal_pos_embed(self): + pos_embed = get_1d_sincos_pos_embed( + self.hidden_size, + self.input_size[0] // self.patch_size[0], + scale=self.time_scale, + ) + pos_embed = torch.from_numpy(pos_embed).float().unsqueeze(0).requires_grad_(False) + return pos_embed + + def freeze_not_temporal(self): + for n, p in self.named_parameters(): + if "attn_temp" not in n: + p.requires_grad = False + + def freeze_text(self): + for n, p in self.named_parameters(): + if "cross_attn" in n: + p.requires_grad = False + + def initialize_temporal(self): + for block in self.blocks: + nn.init.constant_(block.attn_temp.dense.weight, 0) + nn.init.constant_(block.attn_temp.dense.bias, 0) + + def initialize_weights(self): + # Initialize transformer layers: + def _basic_init(module): + if isinstance(module, nn.Linear): + torch.nn.init.xavier_uniform_(module.weight) + if module.bias is not None: + nn.init.constant_(module.bias, 0) + + self.apply(_basic_init) + + # Initialize patch_embed like nn.Linear (instead of nn.Conv2d): + if mpu.is_pipeline_first_stage(): + w = self.x_embedder.proj.weight.data + nn.init.xavier_uniform_(w.view([w.shape[0], -1])) + + # Initialize timestep embedding MLP: + nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02) + nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02) + nn.init.normal_(self.t_block[1].weight, std=0.02) + + # Initialize caption embedding MLP: + nn.init.normal_(self.y_embedder.y_proj.fc1.weight, std=0.02) + nn.init.normal_(self.y_embedder.y_proj.fc2.weight, std=0.02) + + # Zero-out adaLN modulation layers in PixArt blocks: + for block in self.blocks: + nn.init.constant_(block.cross_attn.proj.weight, 0) + nn.init.constant_(block.cross_attn.proj.bias, 0) + + if mpu.is_pipeline_last_stage(): + # Zero-out output layers: + nn.init.constant_(self.final_layer.linear.weight, 0) + nn.init.constant_(self.final_layer.linear.bias, 0) + + +@MODELS.register_module("STDiT-XL/2/mindspeed") +def STDiT_XL_2(**kwargs): + model = STDiT_mindspeed(patch_size=(1, 2, 2), **kwargs) + return model diff 
--git a/PyTorch/built-in/mlm/OpenSora1.0/opensora/models/vae/vae.py b/PyTorch/built-in/mlm/OpenSora1.0/opensora/models/vae/vae.py index 7e800f6ae3..83e1b784ef 100644 --- a/PyTorch/built-in/mlm/OpenSora1.0/opensora/models/vae/vae.py +++ b/PyTorch/built-in/mlm/OpenSora1.0/opensora/models/vae/vae.py @@ -5,7 +5,8 @@ from einops import rearrange from opensora.registry import MODELS from opensora.acceleration.communications import gather_forward_split_backward, split_forward_gather_backward -from opensora.acceleration.parallel_states import get_sequence_parallel_group +from megatron.core import mpu +import mindspeed @MODELS.register_module() @@ -21,8 +22,23 @@ class VideoAutoencoderKL(nn.Module): def encode(self, x): # x: (B, C, T, H, W) B = x.shape[0] + dim_size = x.size(2) + tp_cp_size = mindspeed.core.parallel_state.get_tensor_and_context_parallel_world_size() + cp_size = mpu.get_context_parallel_world_size() + enable_tp_cp_parallel = False + enable_cp_parallel = False if self.enable_sequence_parallelism: - x = split_forward_gather_backward(x, get_sequence_parallel_group(), dim=2, grad_scale="down") + if dim_size % tp_cp_size == 0: + x = split_forward_gather_backward( + x, mindspeed.core.parallel_state.get_tensor_and_context_parallel_group(), dim=2, grad_scale="down") + enable_tp_cp_parallel = True + elif dim_size % cp_size == 0: + x = split_forward_gather_backward(x, mpu.get_context_parallel_group(), dim=2, grad_scale="down") + enable_cp_parallel = True + else: + print(f"Warning: dim_size({dim_size}) is not divisible by tp_cp_size({tp_cp_size}) or " + f"cp_size({cp_size}), so parallelism is not used") + x = rearrange(x, "B C T H W -> (B T) C H W") if self.micro_batch_size is None: @@ -31,14 +47,21 @@ class VideoAutoencoderKL(nn.Module): bs = self.micro_batch_size x_out = [] for i in range(0, x.shape[0], bs): - x_bs = x[i : i + bs] + x_bs = x[i: i + bs] x_bs = self.module.encode(x_bs).latent_dist.sample().mul_(0.18215) x_out.append(x_bs) x = torch.cat(x_out, dim=0) x = rearrange(x, "(B T) C H W -> B C T H W", B=B) if self.enable_sequence_parallelism: - x = gather_forward_split_backward(x, get_sequence_parallel_group(), dim=2, grad_scale="up") + if enable_tp_cp_parallel: + x = gather_forward_split_backward( + x, mindspeed.core.parallel_state.get_tensor_and_context_parallel_group(), dim=2, grad_scale="up") + elif enable_cp_parallel: + x = gather_forward_split_backward(x, mpu.get_context_parallel_group(), dim=2, grad_scale="up") + else: + print(f"Warning: dim_size({dim_size}) is not divisible by tp_cp_size({tp_cp_size}) or " + f"cp_size({cp_size}), so parallelism is not used") return x diff --git a/PyTorch/built-in/mlm/OpenSora1.0/opensora/schedulers/__init__.py b/PyTorch/built-in/mlm/OpenSora1.0/opensora/schedulers/__init__.py index 97ea76f92f..09764f7dcf 100644 --- a/PyTorch/built-in/mlm/OpenSora1.0/opensora/schedulers/__init__.py +++ b/PyTorch/built-in/mlm/OpenSora1.0/opensora/schedulers/__init__.py @@ -1,2 +1,2 @@ from .dpms import DPMS -from .iddpm import IDDPM +from .iddpm import IDDPM, IDDPM_mindspeed diff --git a/PyTorch/built-in/mlm/OpenSora1.0/opensora/schedulers/iddpm/__init__.py b/PyTorch/built-in/mlm/OpenSora1.0/opensora/schedulers/iddpm/__init__.py index 2061dc3a00..0b9d5bacfe 100644 --- a/PyTorch/built-in/mlm/OpenSora1.0/opensora/schedulers/iddpm/__init__.py +++ b/PyTorch/built-in/mlm/OpenSora1.0/opensora/schedulers/iddpm/__init__.py @@ -6,6 +6,7 @@ from opensora.registry import SCHEDULERS from . 
import gaussian_diffusion as gd from .respace import SpacedDiffusion, space_timesteps +from .respace_mindspeed import SpacedDiffusion_mindspeed @SCHEDULERS.register_module("iddpm") @@ -84,6 +85,82 @@ class IDDPM(SpacedDiffusion): return samples +@SCHEDULERS.register_module("iddpm_mindspeed") +class IDDPM_mindspeed(SpacedDiffusion_mindspeed): + def __init__( + self, + num_sampling_steps=None, + timestep_respacing=None, + noise_schedule="linear", + use_kl=False, + sigma_small=False, + predict_xstart=False, + learn_sigma=True, + rescale_learned_sigmas=False, + diffusion_steps=1000, + cfg_scale=4.0, + cfg_channel=None, + ): + betas = gd.get_named_beta_schedule(noise_schedule, diffusion_steps) + if use_kl: + loss_type = gd.LossType.RESCALED_KL + elif rescale_learned_sigmas: + loss_type = gd.LossType.RESCALED_MSE + else: + loss_type = gd.LossType.MSE + if num_sampling_steps is not None: + assert timestep_respacing is None + timestep_respacing = str(num_sampling_steps) + if timestep_respacing is None or timestep_respacing == "": + timestep_respacing = [diffusion_steps] + super().__init__( + use_timesteps=space_timesteps(diffusion_steps, timestep_respacing), + betas=betas, + model_mean_type=(gd.ModelMeanType.EPSILON if not predict_xstart else gd.ModelMeanType.START_X), + model_var_type=( + (gd.ModelVarType.FIXED_LARGE if not sigma_small else gd.ModelVarType.FIXED_SMALL) + if not learn_sigma + else gd.ModelVarType.LEARNED_RANGE + ), + loss_type=loss_type, + # rescale_timesteps=rescale_timesteps, + ) + + self.cfg_scale = cfg_scale + self.cfg_channel = cfg_channel + + def sample( + self, + model, + text_encoder, + z_size, + prompts, + device, + additional_args=None, + ): + n = len(prompts) + z = torch.randn(n, *z_size, device=device) + z = torch.cat([z, z], 0) + model_args = text_encoder.encode(prompts) + y_null = text_encoder.null(n) + model_args["y"] = torch.cat([model_args["y"], y_null], 0) + if additional_args is not None: + model_args.update(additional_args) + + forward = partial(forward_with_cfg, model, cfg_scale=self.cfg_scale, cfg_channel=self.cfg_channel) + samples = self.p_sample_loop( + forward, + z.shape, + z, + clip_denoised=False, + model_kwargs=model_args, + progress=True, + device=device, + ) + samples, _ = samples.chunk(2, dim=0) + return samples + + def forward_with_cfg(model, x, timestep, y, cfg_scale, cfg_channel=None, **kwargs): # https://github.com/openai/glide-text2im/blob/main/notebooks/text2im.ipynb half = x[: len(x) // 2] diff --git a/PyTorch/built-in/mlm/OpenSora1.0/opensora/schedulers/iddpm/gaussian_diffusion_mindspeed.py b/PyTorch/built-in/mlm/OpenSora1.0/opensora/schedulers/iddpm/gaussian_diffusion_mindspeed.py new file mode 100644 index 0000000000..587d231cef --- /dev/null +++ b/PyTorch/built-in/mlm/OpenSora1.0/opensora/schedulers/iddpm/gaussian_diffusion_mindspeed.py @@ -0,0 +1,217 @@ +# Adapted from DiT + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
+# -------------------------------------------------------- +# References: +# DiT: https://github.com/facebookresearch/DiT/tree/main +# GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py +# ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion +# IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py +# -------------------------------------------------------- + + +import numpy as np +import torch as th + +from .diffusion_utils import discretized_gaussian_log_likelihood, normal_kl +from .gaussian_diffusion import GaussianDiffusion, ModelVarType, mean_flat, LossType, ModelMeanType + + +class GaussianDiffusion_mindspeed(GaussianDiffusion): + """ + Utilities for training and sampling diffusion opensora. + Original ported from this codebase: + https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/diffusion_utils_2.py#L42 + :param betas: a 1-D numpy array of betas for each diffusion timestep, + starting at T and going to 1. + """ + def __init__(self, **kwargs): + super().__init__(**kwargs) + + def p_mean_variance(self, model_output, x, t, clip_denoised=True, denoised_fn=None): + """ + Apply the model to get p(x_{t-1} | x_t), as well as a prediction of + the initial x, x_0. + :param model_output: the model, which takes a signal and a batch of timesteps + as input. + :param x: the [N x C x ...] tensor at time t. + :param t: a 1-D Tensor of timesteps. + :param clip_denoised: if True, clip the denoised signal into [-1, 1]. + :param denoised_fn: if not None, a function which applies to the + x_start prediction before it is used to sample. Applies before + clip_denoised. + :param model_kwargs: if not None, a dict of extra keyword arguments to + pass to the model. This can be used for conditioning. + :return: a dict with the following keys: + - 'mean': the model mean output. + - 'variance': the model variance output. + - 'log_variance': the log of 'variance'. + - 'pred_xstart': the prediction for x_0. + """ + + B, C = x.shape[:2] + assert t.shape == (B,) + # model_output = model(x, t, **model_kwargs) + if isinstance(model_output, tuple): + model_output, extra = model_output + else: + extra = None + + if self.model_var_type in [ModelVarType.LEARNED, ModelVarType.LEARNED_RANGE]: + assert model_output.shape == (B, C * 2, *x.shape[2:]) # torch.Size([2, 4, 16, 32, 32]) + model_output, model_var_values = th.split(model_output, C, dim=1) + min_log = _extract_into_tensor(self.posterior_log_variance_clipped, t, x.shape) + max_log = _extract_into_tensor(np.log(self.betas), t, x.shape) + # The model_var_values is [-1, 1] for [min_var, max_var]. + frac = (model_var_values + 1) / 2 + model_log_variance = frac * max_log + (1 - frac) * min_log + model_variance = th.exp(model_log_variance) + else: + model_variance, model_log_variance = { + # for fixedlarge, we set the initial (log-)variance like so + # to get a better decoder log likelihood. 
+ ModelVarType.FIXED_LARGE: ( + np.append(self.posterior_variance[1], self.betas[1:]), + np.log(np.append(self.posterior_variance[1], self.betas[1:])), + ), + ModelVarType.FIXED_SMALL: ( + self.posterior_variance, + self.posterior_log_variance_clipped, + ), + }[self.model_var_type] + model_variance = _extract_into_tensor(model_variance, t, x.shape) + model_log_variance = _extract_into_tensor(model_log_variance, t, x.shape) + + def process_xstart(x): + if denoised_fn is not None: + x = denoised_fn(x) + if clip_denoised: + return x.clamp(-1, 1) + return x + + if self.model_mean_type == ModelMeanType.START_X: + pred_xstart = process_xstart(model_output) + else: + pred_xstart = process_xstart(self._predict_xstart_from_eps(x_t=x, t=t, eps=model_output)) + model_mean, _, _ = self.q_posterior_mean_variance(x_start=pred_xstart, x_t=x, t=t) + + assert model_mean.shape == model_log_variance.shape == pred_xstart.shape == x.shape + return { + "mean": model_mean, + "variance": model_variance, + "log_variance": model_log_variance, + "pred_xstart": pred_xstart, + "extra": extra, + } + + def _vb_terms_bpd(self, model_output, x_start, x_t, t, clip_denoised=True): + """ + Get a term for the variational lower-bound. + The resulting units are bits (rather than nats, as one might expect). + This allows for comparison to other papers. + :return: a dict with the following keys: + - 'output': a shape [N] tensor of NLLs or KLs. + - 'pred_xstart': the x_0 predictions. + """ + true_mean, _, true_log_variance_clipped = self.q_posterior_mean_variance(x_start=x_start, x_t=x_t, t=t) + out = self.p_mean_variance(model_output, x_t, t, clip_denoised=clip_denoised) + kl = normal_kl(true_mean, true_log_variance_clipped, out["mean"], out["log_variance"]) + kl = mean_flat(kl) / np.log(2.0) + + decoder_nll = -discretized_gaussian_log_likelihood( + x_start, means=out["mean"], log_scales=0.5 * out["log_variance"] + ) + assert decoder_nll.shape == x_start.shape + decoder_nll = mean_flat(decoder_nll) / np.log(2.0) + + # At the first timestep return the decoder NLL, + # otherwise return KL(q(x_{t-1}|x_t,x_0) || p(x_{t-1}|x_t)) + output = th.where((t == 0), decoder_nll, kl) + return {"output": output, "pred_xstart": out["pred_xstart"]} + + def training_losses(self, model_output, x_t, x_start, t, model_kwargs=None, noise=None): + """ + Compute training losses for a single timestep. + :param model_output: the model to evaluate loss on. + :param x_start: the [N x C x ...] tensor of inputs. + :param t: a batch of timestep indices. + :param model_kwargs: if not None, a dict of extra keyword arguments to + pass to the model. This can be used for conditioning. + :param noise: if specified, the specific Gaussian noise to try to remove. + :return: a dict with the key "loss" containing a tensor of shape [N]. + Some mean or variance settings may also have other keys. 
+ """ + if model_kwargs is None: + model_kwargs = {} + if noise is None: + noise = th.randn_like(x_start) + + t = t.to(th.int64) + + terms = {} + if self.loss_type == LossType.KL or self.loss_type == LossType.RESCALED_KL: + terms["loss"] = self._vb_terms_bpd( + model_output=model_output, + x_start=x_start, + x_t=x_t, + t=t, + clip_denoised=False + )["output"] + if self.loss_type == LossType.RESCALED_KL: + terms["loss"] *= self.num_timesteps + elif self.loss_type == LossType.MSE or self.loss_type == LossType.RESCALED_MSE: + if self.model_var_type in [ + ModelVarType.LEARNED, + ModelVarType.LEARNED_RANGE, + ]: + B, C = x_start.shape[:2] + assert model_output.shape == (B, C * 2, *x_t.shape[2:]) + model_output, model_var_values = th.split(model_output, C, dim=1) + # Learn the variance using the variational bound, but don't let + # it affect our mean prediction. + frozen_out = th.cat([model_output.detach(), model_var_values], dim=1) + terms["vb"] = self._vb_terms_bpd( + model_output=frozen_out, + x_start=x_start, + x_t=x_t, + t=t, + clip_denoised=False, + )["output"] + if self.loss_type == LossType.RESCALED_MSE: + # Divide by 1000 for equivalence with initial implementation. + # Without a factor of 1/1000, the VB term hurts the MSE term. + terms["vb"] *= self.num_timesteps / 1000.0 + + target = { + ModelMeanType.PREVIOUS_X: self.q_posterior_mean_variance(x_start=x_start, x_t=x_t, t=t)[0], + ModelMeanType.START_X: x_start, + ModelMeanType.EPSILON: noise, + }[self.model_mean_type] + assert model_output.shape == target.shape == x_start.shape + terms["mse"] = mean_flat((target - model_output) ** 2) + if "vb" in terms: + terms["loss"] = terms["mse"] + terms["vb"] + else: + terms["loss"] = terms["mse"] + return terms + else: + raise NotImplementedError(self.loss_type) + + return terms + +def _extract_into_tensor(arr, timesteps, broadcast_shape): + """ + Extract values from a 1-D numpy array for a batch of indices. + :param arr: the 1-D numpy array. + :param timesteps: a tensor of indices into the array to extract. + :param broadcast_shape: a larger shape of K dimensions with the batch + dimension equal to the length of timesteps. + :return: a tensor of shape [batch_size, 1, ...] where the shape has K dims. + """ + timesteps = timesteps.to(th.int64) + res = th.from_numpy(arr).to(device=timesteps.device)[timesteps].float() + + while len(res.shape) < len(broadcast_shape): + res = res[..., None] + return res + th.zeros(broadcast_shape, device=timesteps.device) diff --git a/PyTorch/built-in/mlm/OpenSora1.0/opensora/schedulers/iddpm/respace_mindspeed.py b/PyTorch/built-in/mlm/OpenSora1.0/opensora/schedulers/iddpm/respace_mindspeed.py new file mode 100644 index 0000000000..2dd69353af --- /dev/null +++ b/PyTorch/built-in/mlm/OpenSora1.0/opensora/schedulers/iddpm/respace_mindspeed.py @@ -0,0 +1,50 @@ +# Adapted from DiT + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
+# -------------------------------------------------------- +# References: +# DiT: https://github.com/facebookresearch/DiT/tree/main +# GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py +# ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion +# IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py +# -------------------------------------------------------- + + +import numpy as np + +from .gaussian_diffusion_mindspeed import GaussianDiffusion_mindspeed +from .respace import _WrappedModel + + +class SpacedDiffusion_mindspeed(GaussianDiffusion_mindspeed): + def __init__(self, use_timesteps, **kwargs): + self.use_timesteps = set(use_timesteps) + self.timestep_map = [] + self.original_num_steps = len(kwargs["betas"]) + + base_diffusion = GaussianDiffusion_mindspeed(**kwargs) # pylint: disable=missing-kwoa + last_alpha_cumprod = 1.0 + new_betas = [] + for i, alpha_cumprod in enumerate(base_diffusion.alphas_cumprod): + if i in self.use_timesteps: + new_betas.append(1 - alpha_cumprod / last_alpha_cumprod) + last_alpha_cumprod = alpha_cumprod + self.timestep_map.append(i) + kwargs["betas"] = np.array(new_betas) + super().__init__(**kwargs) + + def condition_mean(self, cond_fn, *args, **kwargs): + return super().condition_mean(self._wrap_model(cond_fn), *args, **kwargs) + + def condition_score(self, cond_fn, *args, **kwargs): + return super().condition_score(self._wrap_model(cond_fn), *args, **kwargs) + + def _wrap_model(self, model): + if isinstance(model, _WrappedModel): + return model + return _WrappedModel(model, self.timestep_map, self.original_num_steps) + + def _scale_timesteps(self, t): + # Scaling is done by the wrapped model. + return t \ No newline at end of file -- Gitee From ff46605697b586bf75e775ea8c741944021ec2ba Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=82=93=E4=BD=B3?= Date: Thu, 25 Jul 2024 11:45:34 +0800 Subject: [PATCH 2/2] [built-in][Pytorch][OpenSora1.0]: blocks_mindspeed and mindspeed_test --- .../models/layers/blocks_mindspeed.py | 456 ++++++++++++++++++ .../opensora/models/vae/__init__.py | 1 + .../OpenSora1.0/opensora/models/vae/vae.py | 33 +- .../opensora/models/vae/vae_mindspeed.py | 89 ++++ .../tests/mindspeed_test/16x256x256.py | 31 ++ .../tests/mindspeed_test/pretrain_opensora.py | 232 +++++++++ .../tests/mindspeed_test/pretrain_opensora.sh | 87 ++++ 7 files changed, 901 insertions(+), 28 deletions(-) create mode 100644 PyTorch/built-in/mlm/OpenSora1.0/opensora/models/layers/blocks_mindspeed.py create mode 100644 PyTorch/built-in/mlm/OpenSora1.0/opensora/models/vae/vae_mindspeed.py create mode 100644 PyTorch/built-in/mlm/OpenSora1.0/tests/mindspeed_test/16x256x256.py create mode 100644 PyTorch/built-in/mlm/OpenSora1.0/tests/mindspeed_test/pretrain_opensora.py create mode 100644 PyTorch/built-in/mlm/OpenSora1.0/tests/mindspeed_test/pretrain_opensora.sh diff --git a/PyTorch/built-in/mlm/OpenSora1.0/opensora/models/layers/blocks_mindspeed.py b/PyTorch/built-in/mlm/OpenSora1.0/opensora/models/layers/blocks_mindspeed.py new file mode 100644 index 0000000000..d71c07954f --- /dev/null +++ b/PyTorch/built-in/mlm/OpenSora1.0/opensora/models/layers/blocks_mindspeed.py @@ -0,0 +1,456 @@ +# Copyright 2024 Huawei Technologies Co., Ltd +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
+# -------------------------------------------------------- +# References: +# PixArt: https://github.com/PixArt-alpha/PixArt-alpha +# Latte: https://github.com/Vchitect/Latte +# DiT: https://github.com/facebookresearch/DiT/tree/main +# GLIDE: https://github.com/openai/glide-text2im +# MAE: https://github.com/facebookresearch/mae/blob/main/models_mae.py +# -------------------------------------------------------- + +import math + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.utils.checkpoint +from timm.models.vision_transformer import Mlp +from megatron.core import mpu, tensor_parallel +from megatron import core +from megatron.training import get_args + +from opensora.acceleration.communications import all_to_all, split_forward_gather_backward +from opensora.utils.device_utils import is_npu_available +if not is_npu_available(): + import xformers.ops +else: + import torch_npu + + +approx_gelu = lambda: nn.GELU(approximate="tanh") + + +def get_layernorm(hidden_size: torch.Tensor, eps: float, affine: bool, use_kernel: bool): + if use_kernel: + try: + from apex.normalization import FusedLayerNorm + + return FusedLayerNorm(hidden_size, elementwise_affine=affine, eps=eps) + except ImportError: + raise RuntimeError("FusedLayerNorm not available. Please install apex.") + else: + return nn.LayerNorm(hidden_size, eps, elementwise_affine=affine) + + +def modulate(norm_func, x, shift, scale): + # Suppose x is (B, N, D), shift is (B, D), scale is (B, D) + dtype = x.dtype + x = norm_func(x.to(torch.float32)).to(dtype) + x = x * (scale.unsqueeze(1) + 1) + shift.unsqueeze(1) + x = x.to(dtype) + return x + + +def t2i_modulate(x, shift, scale): + return x * (1 + scale) + shift + + +# =============================================== +# General-purpose Layers +# =============================================== + + +class PatchEmbed3D(nn.Module): + """Video to Patch Embedding. + + Args: + patch_size (int): Patch token size. Default: (2,4,4). + in_chans (int): Number of input video channels. Default: 3. + embed_dim (int): Number of linear projection output channels. Default: 96. + norm_layer (nn.Module, optional): Normalization layer. 
Default: None + """ + + def __init__( + self, + patch_size=(2, 4, 4), + in_chans=3, + embed_dim=96, + norm_layer=None, + flatten=True, + ): + super().__init__() + self.patch_size = patch_size + self.flatten = flatten + + self.in_chans = in_chans + self.embed_dim = embed_dim + + self.proj = nn.Conv3d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) + if norm_layer is not None: + self.norm = norm_layer(embed_dim) + else: + self.norm = None + + def forward(self, x): + """Forward function.""" + # padding + _, _, D, H, W = x.size() + if W % self.patch_size[2] != 0: + x = F.pad(x, (0, self.patch_size[2] - W % self.patch_size[2])) + if H % self.patch_size[1] != 0: + x = F.pad(x, (0, 0, 0, self.patch_size[1] - H % self.patch_size[1])) + if D % self.patch_size[0] != 0: + x = F.pad(x, (0, 0, 0, 0, 0, self.patch_size[0] - D % self.patch_size[0])) + # (B C T H W) + x = self.proj(x) + if self.norm is not None: + D, Wh, Ww = x.size(2), x.size(3), x.size(4) + x = x.flatten(2).transpose(1, 2) + x = self.norm(x) + x = x.transpose(1, 2).view(-1, self.embed_dim, D, Wh, Ww) + if self.flatten: + # BCTHW -> BNC + x = x.flatten(2).transpose(1, 2) + return x + + +class ParallelMultiHeadCrossAttention(nn.Module): + def __init__( + self, + attn_drop=0.0, + proj_drop=0.0, + config=None): + super(ParallelMultiHeadCrossAttention, self).__init__() + assert config.hidden_size % config.num_attention_heads == 0, \ + "hidden_size must be divisible by num_attention_heads" + args = get_args() + tp_size = mpu.get_tensor_model_parallel_world_size() + self.sp_size = mpu.get_context_parallel_world_size() + self.num_attention_heads_per_partition = core.utils.divide( + config.num_attention_heads, tp_size) + self.num_attention_heads_per_partition_per_cp = core.utils.divide( + self.num_attention_heads_per_partition, self.sp_size) + self.hidden_size_per_attention_head = core.utils.divide( + config.hidden_size, config.num_attention_heads) + query_projection_size = config.kv_channels * config.num_attention_heads + kv_projection_size = args.kv_channels * args.num_attention_heads + + self.q_linear = tensor_parallel.ColumnParallelLinear( + config.hidden_size, + query_projection_size, + config=config, + init_method=config.init_method, + bias=True, + gather_output=False) + self.kv_linear = tensor_parallel.ColumnParallelLinear( + config.hidden_size, + 2 * kv_projection_size, + config=config, + init_method=config.init_method, + bias=True, + gather_output=False) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = tensor_parallel.RowParallelLinear( + query_projection_size, + config.hidden_size, + config=config, + init_method=config.output_layer_init_method, + bias=True, + input_is_parallel=True, + skip_bias_add=False) + self.proj_drop = nn.Dropout(proj_drop) + + def forward(self, x, cond, mask=None): + # query/value: img tokens; key: condition; mask: if padding tokens + cp_group = mpu.get_context_parallel_group() + B, N, C = x.shape + q, _ = self.q_linear(x) + # [B, N, C] --> [B, N, 2 * C] + kv, _ = self.kv_linear(cond) + if mpu.get_context_parallel_world_size() > 1: + N = N * self.sp_size + # q, k, v: [B, SUB_N, NUM_HEADS, HEAD_DIM] + q = q.view(1, -1, self.num_attention_heads_per_partition, + self.hidden_size_per_attention_head) + # [B, N, 2 * C] --> [B, SUB_N, 2, NUM_HEADS, HEAD_DIM] + kv = kv.view(1, -1, 2, self.num_attention_heads_per_partition, + self.hidden_size_per_attention_head) + k, v = kv.unbind(2) + + # apply all_to_all to gather sequence and split attention heads + q = all_to_all(q, cp_group, scatter_dim=2, 
gather_dim=1) + + k = split_forward_gather_backward( + k, cp_group, dim=2, grad_scale="down") + v = split_forward_gather_backward( + v, cp_group, dim=2, grad_scale="down") + else: + kv = kv.view(-1, self.num_attention_heads_per_partition, + 2 * self.hidden_size_per_attention_head) + # [B, N, 2 * C] --> 2 [B, N, C] + (k, v) = tensor_parallel.split_tensor_along_last_dim(kv, 2) + if is_npu_available() and x.dtype in [torch.float16, torch.bfloat16]: + if mpu.get_context_parallel_world_size() > 1: + q = q.view(-1, self.num_attention_heads_per_partition // self.sp_size, + self.hidden_size_per_attention_head) + k = k.view(-1, + self.num_attention_heads_per_partition // self.sp_size, + self.hidden_size_per_attention_head) + v = v.view(-1, + self.num_attention_heads_per_partition // self.sp_size, + self.hidden_size_per_attention_head) + else: + q = q.view(-1, self.num_attention_heads_per_partition, + self.hidden_size_per_attention_head) + k = k.view(-1, self.num_attention_heads_per_partition, + self.hidden_size_per_attention_head) + v = v.view(-1, self.num_attention_heads_per_partition, + self.hidden_size_per_attention_head) + actual_seq_qlen = [] + actual_seq_kvlen = [] + if mask is not None: + ans = 0 + for _ in range(B): + ans += N + actual_seq_qlen.append(ans) + ans = 0 + for m in mask: + ans += m + actual_seq_kvlen.append(ans) + x = torch_npu.npu_fusion_attention( + q, + k, + v, + q.shape[-2], + input_layout="TND", + pse=None, + scale=1.0 / + math.sqrt( + self.hidden_size_per_attention_head), + pre_tockens=65536, + next_tockens=65536, + actual_seq_qlen=tuple(actual_seq_qlen), + actual_seq_kvlen=tuple(actual_seq_kvlen), + keep_prob=1. - + self.attn_drop.p, + sparse_mode=0, + )[0] + else: + q = q.view(1, -1, self.num_attention_heads_per_partition_per_cp, + self.hidden_size_per_attention_head) + kv = kv.view(1, -1, 2, self.num_attention_heads_per_partition_per_cp, + self.hidden_size_per_attention_head) + k, v = kv.unbind(2) + + attn_bias = None + if mask is not None: + attn_bias = xformers.ops.fmha.BlockDiagonalMask.from_seqlens([ + N] * B, mask) + x = xformers.ops.memory_efficient_attention( + q, k, v, p=self.attn_drop.p, attn_bias=attn_bias) + + if mpu.get_context_parallel_world_size() > 1: + # apply all to all to gather back attention heads and scatter + # sequence + x = x.view(1, -1, self.num_attention_heads_per_partition // self.sp_size, + self.hidden_size_per_attention_head) + x = all_to_all(x, cp_group, scatter_dim=1, gather_dim=2) + + new_shape = x.shape[:-2] + (x.shape[-2] * x.shape[-1],) + x = x.view(new_shape) + + x, _ = self.proj(x) + x = x.view(B, -1, C) + x = self.proj_drop(x) + return x + + +class T2IFinalLayer(nn.Module): + """ + The final layer of PixArt. 
+ """ + + def __init__(self, hidden_size, num_patch, out_channels, config=None): + super().__init__() + self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + self.linear = tensor_parallel.ColumnParallelLinear( + hidden_size, + num_patch * out_channels, + config=config, + init_method=config.init_method, + bias=True, + gather_output=True) + self.scale_shift_table = nn.Parameter(torch.randn(2, hidden_size) / hidden_size**0.5) + self.out_channels = out_channels + + def forward(self, x, t): + shift, scale = (self.scale_shift_table[None] + t[:, None]).chunk(2, dim=1) + x = t2i_modulate(self.norm_final(x), shift, scale) + x, _ = self.linear(x) + return x + + +# =============================================== +# Embedding Layers for Timesteps and Class Labels +# =============================================== + + +class TimestepEmbedder(nn.Module): + """ + Embeds scalar timesteps into vector representations. + """ + + def __init__(self, hidden_size, frequency_embedding_size=256, config=None): + super().__init__() + self.mlp = nn.Sequential( + tensor_parallel.ColumnParallelLinear( + frequency_embedding_size, + hidden_size, + config=config, + init_method=config.init_method, + bias=True, + gather_output=False), + nn.SiLU(), + tensor_parallel.RowParallelLinear( + hidden_size, + hidden_size, + config=config, + init_method=config.output_layer_init_method, + bias=True, + input_is_parallel=True, + skip_bias_add=False) + ) + self.frequency_embedding_size = frequency_embedding_size + + @staticmethod + def timestep_embedding(t, dim, max_period=10000): + """ + Create sinusoidal timestep embeddings. + :param t: a 1-D Tensor of N indices, one per batch element. + These may be fractional. + :param dim: the dimension of the output. + :param max_period: controls the minimum frequency of the embeddings. + :return: an (N, D) Tensor of positional embeddings. + """ + # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py + half = dim // 2 + freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half) + freqs = freqs.to(device=t.device) + args = t[:, None].float() * freqs[None] + embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) + if dim % 2: + embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) + return embedding + + def forward(self, t, dtype): + t_freq = self.timestep_embedding(t, self.frequency_embedding_size) + if t_freq.dtype != dtype: + t_freq = t_freq.to(dtype) + t_emb, _ = self.mlp[0](t_freq) + t_emb = self.mlp[1](t_emb) + t_emb, _ = self.mlp[2](t_emb) + return t_emb + + +class CaptionEmbedder(nn.Module): + """ + Embeds class labels into vector representations. Also handles label dropout for classifier-free guidance. + """ + + def __init__(self, in_channels, hidden_size, uncond_prob, act_layer=nn.GELU(approximate="tanh"), token_num=120): + super().__init__() + self.y_proj = Mlp( + in_features=in_channels, hidden_features=hidden_size, out_features=hidden_size, act_layer=act_layer, drop=0 + ) + self.register_buffer("y_embedding", nn.Parameter(torch.randn(token_num, in_channels) / in_channels**0.5)) + self.uncond_prob = uncond_prob + + def token_drop(self, caption, force_drop_ids=None): + """ + Drops labels to enable classifier-free guidance. 
+ """ + if force_drop_ids is None: + drop_ids = torch.rand(caption.shape[0]).cuda() < self.uncond_prob + else: + drop_ids = force_drop_ids == 1 + caption = torch.where(drop_ids[:, None, None, None], self.y_embedding, caption) + return caption + + def forward(self, caption, train, force_drop_ids=None): + if train: + assert caption.shape[2:] == self.y_embedding.shape + use_dropout = self.uncond_prob > 0 + if (train and use_dropout) or (force_drop_ids is not None): + caption = self.token_drop(caption, force_drop_ids) + caption = self.y_proj(caption) + return caption + + +# =============================================== +# Sine/Cosine Positional Embedding Functions +# =============================================== +# https://github.com/facebookresearch/mae/blob/main/util/pos_embed.py + + +def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False, extra_tokens=0, scale=1.0, base_size=None): + """ + grid_size: int of the grid height and width + return: + pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token) + """ + if not isinstance(grid_size, tuple): + grid_size = (grid_size, grid_size) + + grid_h = np.arange(grid_size[0], dtype=np.float32) / scale + grid_w = np.arange(grid_size[1], dtype=np.float32) / scale + if base_size is not None: + grid_h *= base_size / grid_size[0] + grid_w *= base_size / grid_size[1] + grid = np.meshgrid(grid_w, grid_h) # here w goes first + grid = np.stack(grid, axis=0) + + grid = grid.reshape([2, 1, grid_size[1], grid_size[0]]) + pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid) + if cls_token and extra_tokens > 0: + pos_embed = np.concatenate([np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0) + return pos_embed + + +def get_2d_sincos_pos_embed_from_grid(embed_dim, grid): + assert embed_dim % 2 == 0 + + # use half of dimensions to encode grid_h + emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2) + emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2) + + emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D) + return emb + + +def get_1d_sincos_pos_embed(embed_dim, length, scale=1.0): + pos = np.arange(0, length)[..., None] / scale + return get_1d_sincos_pos_embed_from_grid(embed_dim, pos) + + +def get_1d_sincos_pos_embed_from_grid(embed_dim, pos): + """ + embed_dim: output dimension for each position + pos: a list of positions to be encoded: size (M,) + out: (M, D) + """ + assert embed_dim % 2 == 0 + omega = np.arange(embed_dim // 2, dtype=np.float64) + omega /= embed_dim / 2.0 + omega = 1.0 / 10000**omega # (D/2,) + + pos = pos.reshape(-1) # (M,) + out = np.einsum("m,d->md", pos, omega) # (M, D/2), outer product + + emb_sin = np.sin(out) # (M, D/2) + emb_cos = np.cos(out) # (M, D/2) + + emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D) + return emb diff --git a/PyTorch/built-in/mlm/OpenSora1.0/opensora/models/vae/__init__.py b/PyTorch/built-in/mlm/OpenSora1.0/opensora/models/vae/__init__.py index 63510b08b2..39fe5e3703 100644 --- a/PyTorch/built-in/mlm/OpenSora1.0/opensora/models/vae/__init__.py +++ b/PyTorch/built-in/mlm/OpenSora1.0/opensora/models/vae/__init__.py @@ -1 +1,2 @@ from .vae import VideoAutoencoderKL, VideoAutoencoderKLTemporalDecoder +from .vae_mindspeed import VideoAutoencoderKL_mindspeed \ No newline at end of file diff --git a/PyTorch/built-in/mlm/OpenSora1.0/opensora/models/vae/vae.py b/PyTorch/built-in/mlm/OpenSora1.0/opensora/models/vae/vae.py index 83e1b784ef..5e822952a3 100644 --- 
a/PyTorch/built-in/mlm/OpenSora1.0/opensora/models/vae/vae.py +++ b/PyTorch/built-in/mlm/OpenSora1.0/opensora/models/vae/vae.py @@ -5,8 +5,7 @@ from einops import rearrange from opensora.registry import MODELS from opensora.acceleration.communications import gather_forward_split_backward, split_forward_gather_backward -from megatron.core import mpu -import mindspeed +from opensora.acceleration.parallel_states import get_sequence_parallel_group @MODELS.register_module() @@ -22,23 +21,8 @@ class VideoAutoencoderKL(nn.Module): def encode(self, x): # x: (B, C, T, H, W) B = x.shape[0] - dim_size = x.size(2) - tp_cp_size = mindspeed.core.parallel_state.get_tensor_and_context_parallel_world_size() - cp_size = mpu.get_context_parallel_world_size() - enable_tp_cp_parallel = False - enable_cp_parallel = False if self.enable_sequence_parallelism: - if dim_size % tp_cp_size == 0: - x = split_forward_gather_backward( - x, mindspeed.core.parallel_state.get_tensor_and_context_parallel_group(), dim=2, grad_scale="down") - enable_tp_cp_parallel = True - elif dim_size % cp_size == 0: - x = split_forward_gather_backward(x, mpu.get_context_parallel_group(), dim=2, grad_scale="down") - enable_cp_parallel = True - else: - print(f"Warning: dim_size({dim_size}) is not divisible by tp_cp_size({tp_cp_size}) or " - f"cp_size({cp_size}), so parallelism is not used") - + x = split_forward_gather_backward(x, get_sequence_parallel_group(), dim=2, grad_scale="down") x = rearrange(x, "B C T H W -> (B T) C H W") if self.micro_batch_size is None: @@ -47,21 +31,14 @@ class VideoAutoencoderKL(nn.Module): bs = self.micro_batch_size x_out = [] for i in range(0, x.shape[0], bs): - x_bs = x[i: i + bs] + x_bs = x[i : i + bs] x_bs = self.module.encode(x_bs).latent_dist.sample().mul_(0.18215) x_out.append(x_bs) x = torch.cat(x_out, dim=0) - + x = rearrange(x, "(B T) C H W -> B C T H W", B=B) if self.enable_sequence_parallelism: - if enable_tp_cp_parallel: - x = gather_forward_split_backward( - x, mindspeed.core.parallel_state.get_tensor_and_context_parallel_group(), dim=2, grad_scale="up") - elif enable_cp_parallel: - x = gather_forward_split_backward(x, mpu.get_context_parallel_group(), dim=2, grad_scale="up") - else: - print(f"Warning: dim_size({dim_size}) is not divisible by tp_cp_size({tp_cp_size}) or " - f"cp_size({cp_size}), so parallelism is not used") + x = gather_forward_split_backward(x, get_sequence_parallel_group(), dim=2, grad_scale="up") return x diff --git a/PyTorch/built-in/mlm/OpenSora1.0/opensora/models/vae/vae_mindspeed.py b/PyTorch/built-in/mlm/OpenSora1.0/opensora/models/vae/vae_mindspeed.py new file mode 100644 index 0000000000..99265251eb --- /dev/null +++ b/PyTorch/built-in/mlm/OpenSora1.0/opensora/models/vae/vae_mindspeed.py @@ -0,0 +1,89 @@ +import torch +import torch.nn as nn +from diffusers.models import AutoencoderKL, AutoencoderKLTemporalDecoder +from einops import rearrange + +from opensora.registry import MODELS +from opensora.acceleration.communications import gather_forward_split_backward, split_forward_gather_backward +from megatron.core import mpu +import mindspeed + + +@MODELS.register_module("VideoAutoencoderKL_mindspeed") +class VideoAutoencoderKL_mindspeed(nn.Module): + def __init__(self, from_pretrained=None, micro_batch_size=None, enable_sequence_parallelism=False): + super().__init__() + self.module = AutoencoderKL.from_pretrained(from_pretrained) + self.out_channels = self.module.config.latent_channels + self.patch_size = (1, 8, 8) + self.micro_batch_size = micro_batch_size + 
self.enable_sequence_parallelism = enable_sequence_parallelism + + def encode(self, x): + # x: (B, C, T, H, W) + B = x.shape[0] + dim_size = x.size(2) + tp_cp_size = mindspeed.core.parallel_state.get_tensor_and_context_parallel_world_size() + cp_size = mpu.get_context_parallel_world_size() + enable_tp_cp_parallel = False + enable_cp_parallel = False + if self.enable_sequence_parallelism: + if dim_size % tp_cp_size == 0: + x = split_forward_gather_backward( + x, mindspeed.core.parallel_state.get_tensor_and_context_parallel_group(), dim=2, grad_scale="down") + enable_tp_cp_parallel = True + elif dim_size % cp_size == 0: + x = split_forward_gather_backward(x, mpu.get_context_parallel_group(), dim=2, grad_scale="down") + enable_cp_parallel = True + else: + print(f"Warning: dim_size({dim_size}) is not divisible by tp_cp_size({tp_cp_size}) or " + f"cp_size({cp_size}), so parallelism is not used") + + x = rearrange(x, "B C T H W -> (B T) C H W") + + if self.micro_batch_size is None: + x = self.module.encode(x).latent_dist.sample().mul_(0.18215) + else: + bs = self.micro_batch_size + x_out = [] + for i in range(0, x.shape[0], bs): + x_bs = x[i: i + bs] + x_bs = self.module.encode(x_bs).latent_dist.sample().mul_(0.18215) + x_out.append(x_bs) + x = torch.cat(x_out, dim=0) + + x = rearrange(x, "(B T) C H W -> B C T H W", B=B) + if self.enable_sequence_parallelism: + if enable_tp_cp_parallel: + x = gather_forward_split_backward( + x, mindspeed.core.parallel_state.get_tensor_and_context_parallel_group(), dim=2, grad_scale="up") + elif enable_cp_parallel: + x = gather_forward_split_backward(x, mpu.get_context_parallel_group(), dim=2, grad_scale="up") + else: + print(f"Warning: dim_size({dim_size}) is not divisible by tp_cp_size({tp_cp_size}) or " + f"cp_size({cp_size}), so parallelism is not used") + + return x + + def decode(self, x): + # x: (B, C, T, H, W) + B = x.shape[0] + x = rearrange(x, "B C T H W -> (B T) C H W") + if self.micro_batch_size is None: + x = self.module.decode(x / 0.18215).sample + else: + bs = self.micro_batch_size + x_out = [] + for i in range(0, x.shape[0], bs): + x_bs = x[i : i + bs] + x_bs = self.module.decode(x_bs / 0.18215).sample + x_out.append(x_bs) + x = torch.cat(x_out, dim=0) + x = rearrange(x, "(B T) C H W -> B C T H W", B=B) + return x + + def get_latent_size(self, input_size): + for i in range(3): + assert input_size[i] % self.patch_size[i] == 0, "Input size must be divisible by patch size" + input_size = [input_size[i] // self.patch_size[i] for i in range(3)] + return input_size diff --git a/PyTorch/built-in/mlm/OpenSora1.0/tests/mindspeed_test/16x256x256.py b/PyTorch/built-in/mlm/OpenSora1.0/tests/mindspeed_test/16x256x256.py new file mode 100644 index 0000000000..24ddf2eb64 --- /dev/null +++ b/PyTorch/built-in/mlm/OpenSora1.0/tests/mindspeed_test/16x256x256.py @@ -0,0 +1,31 @@ +num_frames = 16 +frame_interval = 3 +image_size = (256, 256) + +# Define dataset +root = None +use_image_transform = False +num_workers = 1 + +# Define model +model = dict( + type="STDiT-XL/2/mindspeed", + space_scale=0.5, + time_scale=1.0, + enable_layernorm_kernel=True, +) +vae = dict( + type="VideoAutoencoderKL_mindspeed", + from_pretrained="sd-vae-ft-ema", + enable_sequence_parallelism=True, +) +text_encoder = dict( + type="t5", + from_pretrained="DeepFloyd/t5-v1_1-xxl", + model_max_length=120, + shardformer=False, +) +scheduler = dict( + type="iddpm_mindspeed", + timestep_respacing="", +) \ No newline at end of file diff --git 
a/PyTorch/built-in/mlm/OpenSora1.0/tests/mindspeed_test/pretrain_opensora.py b/PyTorch/built-in/mlm/OpenSora1.0/tests/mindspeed_test/pretrain_opensora.py new file mode 100644 index 0000000000..1c0048124b --- /dev/null +++ b/PyTorch/built-in/mlm/OpenSora1.0/tests/mindspeed_test/pretrain_opensora.py @@ -0,0 +1,232 @@ +# Copyright (c) 2024, Huawei Technologies Co., Ltd. All rights reserved. + +from functools import partial + +from mmengine.config import Config +import torch + +import mindspeed.megatron_adaptor +from mindspeed.utils import get_batch_on_this_cp_rank + +from megatron.core import mpu +from megatron.core.enums import ModelType +from megatron.training import get_args, get_timers, print_rank_0, pretrain +from megatron.training.utils import average_losses_across_data_parallel_group +from megatron.training.arguments import core_transformer_config_from_args + +from opensora.utils.train_utils import update_ema +from opensora.utils.misc import requires_grad +from opensora.datasets import DatasetFromCSV, get_transforms_video, get_transforms_image, prepare_dataloader +from opensora.registry import MODELS, SCHEDULERS, build_module + +scheduler = None +vae = None +text_encoder = None +cfg = None + + +def initialize(): + args = get_args() + dtype = args.params_dtype + + def initialize_models(): + global cfg + global vae + global text_encoder + + vae_cfg = cfg.vae + text_encoder_cfg = cfg.text_encoder + vae = build_module(vae_cfg, MODELS) + text_encoder = build_module(text_encoder_cfg, MODELS, device=torch.cuda.current_device()) + + vae = vae.to(torch.cuda.current_device(), dtype) + + def initialize_scheduler(): + global scheduler + scheduler = build_module(cfg.scheduler, SCHEDULERS) + + def initialize_config(): + args = get_args() + global cfg + cfg = Config.fromfile(args.additional_config) + + initialize_config() + initialize_scheduler() + initialize_models() + + +def initialize_pipeline_tensor_shapes(hidden_size): + args = get_args() + micro_batch_size = args.micro_batch_size + dtype = args.params_dtype + latent_size = vae.get_latent_size((cfg.num_frames, *cfg.image_size)) + text_encoder_maxlen = text_encoder.model_max_length + args.pipeline_tensor_shapes = [ + {'shape': (micro_batch_size, text_encoder.output_dim, hidden_size), 'dtype': dtype}, + {'shape': (micro_batch_size, vae.out_channels, *latent_size), 'dtype': dtype}, + {'shape': (micro_batch_size, 1, text_encoder_maxlen, hidden_size), 'dtype': dtype}, + {'shape': (micro_batch_size,), 'dtype': torch.float32}, + {'shape': (micro_batch_size, hidden_size * 6), 'dtype': dtype}, + {'shape': (micro_batch_size, text_encoder_maxlen), 'dtype': torch.float32}, + {'shape': (micro_batch_size, vae.out_channels, *latent_size), 'dtype': dtype}, + {'shape': (micro_batch_size, vae.out_channels, *latent_size), 'dtype': dtype}, + {'shape': (micro_batch_size, hidden_size), 'dtype': dtype} + ] + setattr(forward_step, 'pipeline_tensor_shapes', args.pipeline_tensor_shapes) + + +def model_provider(pre_process=True, post_process=True): + """Build the model.""" + args = get_args() + dtype = args.params_dtype + latent_size = vae.get_latent_size((cfg.num_frames, *cfg.image_size)) + stdit = build_module( + cfg.model, + MODELS, + input_size=latent_size, + in_channels=vae.out_channels, + caption_channels=text_encoder.output_dim, + model_max_length=text_encoder.model_max_length, + dtype=dtype, + ) + + ema = build_module( + cfg.model, + MODELS, + input_size=latent_size, + in_channels=vae.out_channels, + caption_channels=text_encoder.output_dim, + 
model_max_length=text_encoder.model_max_length, + dtype=dtype, + ) + + requires_grad(ema, False) + stdit.ema = ema + update_ema(ema, stdit, decay=0, sharded=False) + ema.eval() + + initialize_pipeline_tensor_shapes(stdit.hidden_size) + stdit.config = core_transformer_config_from_args(get_args()) + return stdit + + +def get_batch_on_this_tp_rank(data_iterator): + global vae + global text_encoder + args = get_args() + dtype = args.params_dtype + + if data_iterator is not None: + batch = next(data_iterator) + else: + batch = None + # x.shape: [B, C, T, H/P, W/P] + x = batch['video'].to(torch.cuda.current_device(), dtype) + y = batch['text'] + + with torch.no_grad(): + # Prepare visual inputs + # x.shape: [B, C, T, H/P, W/P] + x = vae.encode(x).contiguous() + # Prepare text inputs + encoded_text = text_encoder.encode(y) + y = encoded_text['y'].contiguous() + mask = encoded_text['mask'].contiguous() + + batch = { + 'x': x, + 'y': y, + 'mask': mask + } + return batch + + +def get_batch(data_iterator): + """Build the batch.""" + + if mpu.is_pipeline_first_stage(): + batch = get_batch_on_this_tp_rank(data_iterator) + return batch['x'], batch['y'], batch['mask'] + else: + return None, None, None + + +def loss_func(x_t, x_0, t, noise, output_tensor): + loss_dict = scheduler.training_losses(output_tensor[0], x_t, x_0, t, noise=noise) + loss = loss_dict["loss"].mean() + averaged_loss = average_losses_across_data_parallel_group([loss]) + loss = loss.unsqueeze(0) + return loss, {"loss": averaged_loss[0]} + + +def forward_step(data_iterator, model): + """Forward step.""" + args = get_args() + dtype = args.params_dtype + + # Get the batch. + x, y, mask = get_batch(data_iterator) + + num_timesteps = 1000 + micro_bs = args.micro_batch_size + timestep = None + x_0 = None + noise = None + + if mpu.is_pipeline_first_stage(): + x_0 = x.clone() + timestep = torch.randint(0, num_timesteps, (micro_bs,), device=torch.cuda.current_device(), dtype=torch.int64) + noise = torch.randn_like(x) + noise = noise.to(device=torch.cuda.current_device(), dtype=dtype) + x = scheduler.q_sample(x, timestep, noise=noise) + x_t = x.clone() + + if mpu.get_pipeline_model_parallel_world_size() > 1: + x, x_t, y, timestep, t0, mask, x_0, noise, t = model(x, timestep, y, x_0, noise, mask) + output_tensor_wrap = [x, x_t, y, timestep, t0, mask, x_0, noise, t] + else: + x = model(x, timestep, y, x_0, noise, mask) + output_tensor_wrap = [x] + + return output_tensor_wrap, partial(loss_func, x_t, x_0, timestep, noise) + + +def train_valid_test_datasets_provider(train_val_test_num_samples): + """Build train, valid, and test datasets.""" + args = get_args() + + dataset = DatasetFromCSV( + args.data_path[0], + transform=( + get_transforms_video(cfg.image_size[0]) + if not cfg.use_image_transform + else get_transforms_image(cfg.image_size[0]) + ), + num_frames=cfg.num_frames, + frame_interval=cfg.frame_interval, + root=cfg.root, + ) + + dataloader = prepare_dataloader( + dataset, + batch_size=args.micro_batch_size, + num_workers=cfg.num_workers, + shuffle=True, + drop_last=True, + pin_memory=True, + process_group=mpu.get_data_parallel_group(), + ) + + return iter(dataloader), None, None + + +if __name__ == "__main__": + train_valid_test_datasets_provider.is_distributed = True + pretrain( + train_valid_test_datasets_provider, + model_provider, + ModelType.encoder_or_decoder, + forward_step, + args_defaults={'dataloader_type': 'external', + 'init_func': initialize} + ) diff --git 
a/PyTorch/built-in/mlm/OpenSora1.0/tests/mindspeed_test/pretrain_opensora.sh b/PyTorch/built-in/mlm/OpenSora1.0/tests/mindspeed_test/pretrain_opensora.sh new file mode 100644 index 0000000000..5731a4f13a --- /dev/null +++ b/PyTorch/built-in/mlm/OpenSora1.0/tests/mindspeed_test/pretrain_opensora.sh @@ -0,0 +1,86 @@ + +export CUDA_DEVICE_MAX_CONNECTIONS=1 + +GPUS_PER_NODE=8 +MASTER_ADDR=localhost +MASTER_PORT=6900 +NNODES=1 +NODE_RANK=0 +WORLD_SIZE=$(($GPUS_PER_NODE*$NNODES)) + +CKPT_SAVE_DIR="your model save ckpt path" +CKPT_LOAD_DIR="your model load ckpt path" +DATA_PATH="your data path" +SCRIPT_CONFIG="16x256x256.py" + +TP=2 +PP=2 +CP=2 + +DISTRIBUTED_ARGS=" + --nproc_per_node $GPUS_PER_NODE \ + --nnodes $NNODES \ + --node_rank $NODE_RANK \ + --master_addr $MASTER_ADDR \ + --master_port $MASTER_PORT +" + +GPT_ARGS=" + --tensor-model-parallel-size ${TP} \ + --use-multiparameter-pipeline-model-parallel \ + --pipeline-model-parallel-size ${PP} \ + --context-parallel-size ${CP} \ + --context-parallel-algo ulysses_cp_algo \ + --micro-batch-size 4 \ + --global-batch-size 4 \ + --num-layers 28 \ + --hidden-size 1152 \ + --num-attention-heads 16 \ + --seq-length 1024 \ + --max-position-embeddings 1024 \ + --attention-dropout 0.0 \ + --hidden-dropout 0.0 \ + --tokenizer-type NullTokenizer \ + --vocab-size 0 \ + --position-embedding-type rope \ + --rotary-base 500000 \ + --swiglu \ + --no-masked-softmax-fusion \ + --lr 2e-5 \ + --min-lr 2e-5 \ + --train-iters 2500 \ + --weight-decay 0.0 \ + --clip-grad 0.0 \ + --adam-beta1 0.9 \ + --adam-beta2 0.95 \ + --initial-loss-scale 4096 \ + --no-gradient-accumulation-fusion \ + --no-load-optim \ + --no-load-rng \ + --bf16 \ + --use-ema +" + +DATA_ARGS=" + --data-path $DATA_PATH \ +" + +MODEL_ARGS=" + --additional-config $SCRIPT_CONFIG \ +" + +OUTPUT_ARGS=" + --log-interval 1 \ + --save-interval 10000 \ + --eval-interval 10000 \ + --eval-iters 10 \ +" +torchrun $DISTRIBUTED_ARGS pretrain_opensora.py \ + $GPT_ARGS \ + $DATA_ARGS \ + $MODEL_ARGS \ + $OUTPUT_ARGS \ + --distributed-backend nccl \ + --save ${CKPT_SAVE_DIR} + +set +x
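Note for reviewers: the training flow in pretrain_opensora.py is "sample a timestep, noise the VAE latent with scheduler.q_sample, run STDiT, compute the diffusion loss with scheduler.training_losses". The standalone sketch below reproduces that flow with the textbook DDPM closed-form noising step and a plain epsilon-prediction MSE loss on a toy latent; it is an illustration only and does not use the patch's iddpm_mindspeed scheduler, the Megatron pipeline wiring, or the data-parallel loss averaging. The linear beta schedule and the dummy predictor are assumptions made for the example.

# Illustrative sketch of the noising + loss step performed in forward_step/loss_func.
# Assumes a linear beta schedule and an epsilon-prediction MSE loss (not the exact
# IDDPM formulation used by iddpm_mindspeed in this patch).
import torch
import torch.nn.functional as F

num_timesteps = 1000
betas = torch.linspace(1e-4, 0.02, num_timesteps)      # assumed schedule
alphas_cumprod = torch.cumprod(1.0 - betas, dim=0)     # cumulative product \bar{alpha}_t

def q_sample(x_0, t, noise):
    # x_t = sqrt(abar_t) * x_0 + sqrt(1 - abar_t) * noise
    abar = alphas_cumprod[t].view(-1, 1, 1, 1, 1)
    return abar.sqrt() * x_0 + (1.0 - abar).sqrt() * noise

# Toy latent shaped like the VAE output: (B, C, T, H/8, W/8)
x_0 = torch.randn(2, 4, 16, 32, 32)
t = torch.randint(0, num_timesteps, (x_0.shape[0],))
noise = torch.randn_like(x_0)
x_t = q_sample(x_0, t, noise)

eps_pred = torch.zeros_like(x_t)    # the real run calls the STDiT model here
loss = F.mse_loss(eps_pred, noise)
print(loss.item())

In the actual script the per-micro-batch loss is additionally averaged across the data-parallel group via average_losses_across_data_parallel_group, and the intermediate tensors are exchanged between pipeline stages according to args.pipeline_tensor_shapes. To launch the full run, fill in DATA_PATH, CKPT_SAVE_DIR and the checkpoint paths in pretrain_opensora.sh and execute it with bash on a single 8-device node (TP=2, PP=2, CP=2).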