diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..aab021059e45449b8fdca0db1c819e119be32e50 --- /dev/null +++ b/.gitignore @@ -0,0 +1,13 @@ +__pycache__/ +kernel_meta/ +prof*/ +*.egg-info +*.mp4 +*.jpg +*.png +models/ +*.msc +*.mv +fusion_result.json +*.log +*.tar.gz \ No newline at end of file diff --git a/diffsynth/__init__.py b/diffsynth/__init__.py index ae0a45c2e2dc61f8f16354feb1b0c481776b523f..99dea25add19cc2769e00632c6ae05394a3fe51c 100644 --- a/diffsynth/__init__.py +++ b/diffsynth/__init__.py @@ -4,3 +4,4 @@ from .prompters import * from .schedulers import * from .pipelines import * from .controlnets import * +from .npu_utils import * \ No newline at end of file diff --git a/diffsynth/distributed/xdit_context_parallel.py b/diffsynth/distributed/xdit_context_parallel.py index 4887e2f16a87fa3a28e8057de41c6537be83f8ed..a3d9f9cd0f4a1c4a10ee505ac77b9e723cd4981c 100644 --- a/diffsynth/distributed/xdit_context_parallel.py +++ b/diffsynth/distributed/xdit_context_parallel.py @@ -1,14 +1,31 @@ +import os import torch from typing import Optional from einops import rearrange -from xfuser.core.distributed import (get_sequence_parallel_rank, - get_sequence_parallel_world_size, - get_sp_group) -from xfuser.core.long_ctx_attention import xFuserLongContextAttention + +if torch.npu.is_available(): + ASCEND_NPU_AVAILABLE = True + from diffsynth.npu_utils.distributed.parallel_mgr import ( + get_sequence_parallel_rank, + get_sequence_parallel_world_size, + get_sp_group) + from diffsynth.npu_utils.modules.attn_layer import xFuserLongContextAttention +else: + from xfuser.core.distributed import (get_sequence_parallel_rank, + get_sequence_parallel_world_size, + get_sp_group) + from xfuser.core.long_ctx_attention import xFuserLongContextAttention + +try: + from mindiesd import rotary_position_embedding + NPU_ROPE_AVAILABLE = True +except: + NPU_ROPE_AVAILABLE = False def sinusoidal_embedding_1d(dim, position): - sinusoid = torch.outer(position.type(torch.float64), torch.pow( - 10000, -torch.arange(dim//2, dtype=torch.float64, device=position.device).div(dim//2))) + position_dtype = torch.float32 if torch.npu.is_available() else torch.float64 + sinusoid = torch.outer(position.type(position_dtype), torch.pow( + 10000, -torch.arange(dim//2, dtype=position_dtype, device=position.device).div(dim//2))) x = torch.cat([torch.cos(sinusoid), torch.sin(sinusoid)], dim=1) return x.to(position.dtype) @@ -19,12 +36,36 @@ def pad_freqs(original_tensor, target_len): pad_size, s1, s2, - dtype=original_tensor.dtype, + dtype=torch.float32 if torch.npu.is_available() else original_tensor.dtype, device=original_tensor.device) padded_tensor = torch.cat([original_tensor, padding_tensor], dim=0) return padded_tensor -def rope_apply(x, freqs, num_heads): +def rope_apply_npu(x, freqs, num_heads): + # print('rope_apply under xdit_context_parallel.py') + if NPU_ROPE_AVAILABLE: + x = rearrange(x, "b s (n d) -> b s n d", n=num_heads) + cos, sin = freqs + x_out = rotary_position_embedding(x, cos, sin, rotated_mode="rotated_interleaved", fused=True) + return x_out.flatten(2).to(x.dtype) + else: + x = rearrange(x, "b s (n d) -> b s n d", n=num_heads) + s_per_rank = x.shape[1] + + dtype = torch.float32 + + x_out = torch.view_as_complex(x.to(dtype).reshape( + x.shape[0], x.shape[1], x.shape[2], -1, 2)) + + sp_size = get_sequence_parallel_world_size() + sp_rank = get_sequence_parallel_rank() + freqs = pad_freqs(freqs, s_per_rank * sp_size) + freqs_rank = freqs[(sp_rank * 
s_per_rank):((sp_rank + 1) * s_per_rank), :, :] + + x_out = torch.view_as_real(x_out * freqs_rank).flatten(2) + return x_out.to(x.dtype) + +def rope_apply_cuda(x, freqs, num_heads): x = rearrange(x, "b s (n d) -> b s n d", n=num_heads) s_per_rank = x.shape[1] @@ -39,6 +80,12 @@ def rope_apply(x, freqs, num_heads): x_out = torch.view_as_real(x_out * freqs_rank).flatten(2) return x_out.to(x.dtype) +def rope_apply(x, freqs, num_heads): + if torch.npu.is_available(): + return rope_apply_npu(x, freqs, num_heads) + else: + return rope_apply_cuda(x, freqs, num_heads) + def usp_dit_forward(self, x: torch.Tensor, timestep: torch.Tensor, @@ -126,6 +173,41 @@ def usp_attn_forward(self, x, freqs): ) x = x.flatten(2) - del q, k, v - torch.cuda.empty_cache() - return self.o(x) \ No newline at end of file + if not torch.npu.is_available(): + del q, k, v + torch.cuda.empty_cache() + return self.o(x) + +def usp_dit_forward_vace(self, + x: torch.Tensor, + vace_context, + **kwargs, + ): + + """ + + c = block(c, x, context, t_mod, freqs) + hints = torch.unbind(c)[:-1] + return hints + """ + c = [self.vace_patch_embedding(u.unsqueeze(0)) for u in vace_context] + c = [u.flatten(2).transpose(1, 2) for u in c] + c = torch.cat([ + torch.cat([u, u.new_zeros(1, x.shape[1] * get_sequence_parallel_world_size() - u.size(1), u.size(2))], dim=1) + for u in c + ]) + + # arguments + new_kwargs = dict(x=x) + new_kwargs.update(kwargs) + + # Context Parallel + c = torch.chunk( + c, get_sequence_parallel_world_size(), + dim=1)[get_sequence_parallel_rank()] + + for block in self.vace_blocks: + c = block(c, **new_kwargs) + + hints = torch.unbind(c)[:-1] + return hints \ No newline at end of file diff --git a/diffsynth/models/wan_video_dit.py b/diffsynth/models/wan_video_dit.py index 419f8cf855780d96884405155f0652b8c79c178b..dc8face38851b781e2f219b4777eba8ae676fe0f 100644 --- a/diffsynth/models/wan_video_dit.py +++ b/diffsynth/models/wan_video_dit.py @@ -1,4 +1,7 @@ +import os import torch +import torch_npu +import torch.distributed as dist import torch.nn as nn import torch.nn.functional as F import math @@ -66,9 +69,21 @@ def modulate(x: torch.Tensor, shift: torch.Tensor, scale: torch.Tensor): def sinusoidal_embedding_1d(dim, position): - sinusoid = torch.outer(position.type(torch.float64), torch.pow( - 10000, -torch.arange(dim//2, dtype=torch.float64, device=position.device).div(dim//2))) - x = torch.cat([torch.cos(sinusoid), torch.sin(sinusoid)], dim=1) + if torch.npu.is_available(): + # preprocess + assert dim % 2 == 0 + half = dim // 2 + position = position.type(torch.float32) + + # calculation + sinusoid = torch.outer( + position, torch.pow(10000, -torch.arange(half).to(position).div(half))) + x = torch.cat([torch.cos(sinusoid), torch.sin(sinusoid)], dim=1) + return x.float() + else: + sinusoid = torch.outer(position.type(torch.float64), torch.pow( + 10000, -torch.arange(dim//2, dtype=torch.float64, device=position.device).div(dim//2))) + x = torch.cat([torch.cos(sinusoid), torch.sin(sinusoid)], dim=1) return x.to(position.dtype) @@ -82,10 +97,21 @@ def precompute_freqs_cis_3d(dim: int, end: int = 1024, theta: float = 10000.0): def precompute_freqs_cis(dim: int, end: int = 1024, theta: float = 10000.0): # 1d rope precompute - freqs = 1.0 / (theta ** (torch.arange(0, dim, 2) - [: (dim // 2)].double() / dim)) - freqs = torch.outer(torch.arange(end, device=freqs.device), freqs) - freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64 + if torch.npu.is_available(): + freqs = 1.0 / (theta ** (torch.arange(0, dim, 2) 
+ [: (dim // 2)].to(torch.float32) / dim)) + if dist.is_initialized(): + if freqs.device != torch.device(f"npu:{dist.get_rank()}"): + freqs = freqs.to(torch.device(f"npu:{dist.get_rank()}")) + else: + freqs = freqs.to(torch.device("npu")) + freqs = torch.outer(torch.arange(end, device=freqs.device), freqs) + freqs_cis = torch.polar(torch.ones_like(freqs), freqs).to(torch.complex64) # complex64 + else: + freqs = 1.0 / (theta ** (torch.arange(0, dim, 2) + [: (dim // 2)].double() / dim)) + freqs = torch.outer(torch.arange(end, device=freqs.device), freqs) + freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64 return freqs_cis @@ -108,8 +134,25 @@ class RMSNorm(nn.Module): def forward(self, x): dtype = x.dtype - return self.norm(x.float()).to(dtype) * self.weight + if torch.npu.is_available(): + return torch_npu.npu_rms_norm(x, self.weight, epsilon=self.eps)[0] + else: + return self.norm(x.float()).to(dtype) * self.weight + +class LayerNormNPU(nn.LayerNorm): + def __init__(self, dim, eps=1e-6, elementwise_affine=False): + super().__init__(dim, elementwise_affine=elementwise_affine, eps=eps) + self.dim = dim + + def forward(self, x): + r""" + Args: + x(Tensor): Shape [B, L, C] + """ + return torch_npu.npu_layer_norm_eval( + x, normalized_shape=[self.dim], weight=self.weight, bias=self.bias, eps=self.eps, + ) class AttentionModule(nn.Module): def __init__(self, num_heads): @@ -203,9 +246,9 @@ class DiTBlock(nn.Module): self.self_attn = SelfAttention(dim, num_heads, eps) self.cross_attn = CrossAttention( dim, num_heads, eps, has_image_input=has_image_input) - self.norm1 = nn.LayerNorm(dim, eps=eps, elementwise_affine=False) - self.norm2 = nn.LayerNorm(dim, eps=eps, elementwise_affine=False) - self.norm3 = nn.LayerNorm(dim, eps=eps) + self.norm1 = LayerNormNPU(dim, eps) if torch.npu.is_available() else nn.LayerNorm(dim, eps=eps, elementwise_affine=False) + self.norm2 = LayerNormNPU(dim, eps) if torch.npu.is_available() else nn.LayerNorm(dim, eps=eps, elementwise_affine=False) + self.norm3 = LayerNormNPU(dim, eps, elementwise_affine=True) if torch.npu.is_available() else nn.LayerNorm(dim, eps=eps) self.ffn = nn.Sequential(nn.Linear(dim, ffn_dim), nn.GELU( approximate='tanh'), nn.Linear(ffn_dim, dim)) self.modulation = nn.Parameter(torch.randn(1, 6, dim) / dim**0.5) @@ -234,11 +277,11 @@ class MLP(torch.nn.Module): def __init__(self, in_dim, out_dim, has_pos_emb=False): super().__init__() self.proj = torch.nn.Sequential( - nn.LayerNorm(in_dim), + LayerNormNPU(dim=in_dim) if torch.npu.is_available() else nn.LayerNorm(in_dim), nn.Linear(in_dim, in_dim), nn.GELU(), nn.Linear(in_dim, out_dim), - nn.LayerNorm(out_dim) + LayerNormNPU(dim=out_dim) if torch.npu.is_available() else nn.LayerNorm(out_dim) ) self.has_pos_emb = has_pos_emb if has_pos_emb: @@ -255,7 +298,7 @@ class Head(nn.Module): super().__init__() self.dim = dim self.patch_size = patch_size - self.norm = nn.LayerNorm(dim, eps=eps, elementwise_affine=False) + self.norm = LayerNormNPU(dim, eps, elementwise_affine=False) if torch.npu.is_available() else nn.LayerNorm(dim, eps=eps, elementwise_affine=False) self.head = nn.Linear(dim, out_dim * math.prod(patch_size)) self.modulation = nn.Parameter(torch.randn(1, 2, dim) / dim**0.5) @@ -301,6 +344,7 @@ class WanModel(torch.nn.Module): self.require_vae_embedding = require_vae_embedding self.require_clip_embedding = require_clip_embedding self.fuse_vae_embedding_in_latents = fuse_vae_embedding_in_latents + self.num_heads = num_heads self.patch_embedding = nn.Conv3d( in_dim, dim, 
kernel_size=patch_size, stride=patch_size) @@ -323,6 +367,7 @@ class WanModel(torch.nn.Module): self.head = Head(dim, out_dim, patch_size, eps) head_dim = dim // num_heads self.freqs = precompute_freqs_cis_3d(head_dim) + self.freqs_list = None if has_image_input: self.img_emb = MLP(1280, dim, has_pos_emb=has_image_pos_emb) # clip_feature_dim = 1280 diff --git a/diffsynth/models/wan_video_image_encoder.py b/diffsynth/models/wan_video_image_encoder.py index 5ca878b1fd6ed6dc00420f092f87479fb65ef63a..8f0d28071fcb3859c4199364883f947f01822198 100644 --- a/diffsynth/models/wan_video_image_encoder.py +++ b/diffsynth/models/wan_video_image_encoder.py @@ -5,6 +5,7 @@ Concise re-implementation of """ import math import torch +import torch_npu import torch.nn as nn import torch.nn.functional as F import torchvision.transforms as T @@ -50,6 +51,20 @@ class SelfAttention(nn.Module): return x +class LayerNormNPU(nn.LayerNorm): + def __init__(self, dim, eps=1e-6, elementwise_affine=False): + super().__init__(dim, elementwise_affine=elementwise_affine, eps=eps) + self.dim = dim + + def forward(self, x): + r""" + Args: + x(Tensor): Shape [B, L, C] + """ + return torch_npu.npu_layer_norm_eval( + x, normalized_shape=[self.dim], weight=self.weight, bias=self.bias, eps=self.eps, + ) + class AttentionBlock(nn.Module): def __init__(self, dim, num_heads, post_norm, dropout=0.1, eps=1e-5): @@ -61,11 +76,11 @@ class AttentionBlock(nn.Module): # layers self.attn = SelfAttention(dim, num_heads, dropout, eps) - self.norm1 = nn.LayerNorm(dim, eps=eps) + self.norm1 = LayerNormNPU(dim, eps) if torch.npu.is_available() else nn.LayerNorm(dim, eps=eps) self.ffn = nn.Sequential( nn.Linear(dim, dim * 4), nn.GELU(), nn.Linear(dim * 4, dim), nn.Dropout(dropout)) - self.norm2 = nn.LayerNorm(dim, eps=eps) + self.norm2 = LayerNormNPU(dim, eps) if torch.npu.is_available() else nn.LayerNorm(dim, eps=eps) def forward(self, x, mask): if self.post_norm: @@ -117,7 +132,7 @@ class XLMRoberta(nn.Module): ]) # norm layer - self.norm = nn.LayerNorm(dim, eps=eps) + self.norm = LayerNormNPU(dim, eps) if torch.npu.is_available() else nn.LayerNorm(dim, eps=eps) def forward(self, ids): """ @@ -308,10 +323,10 @@ class AttentionBlock(nn.Module): self.norm_eps = norm_eps # layers - self.norm1 = LayerNorm(dim, eps=norm_eps) + self.norm1 = LayerNormNPU(dim, norm_eps) if torch.npu.is_available() else LayerNorm(dim, eps=norm_eps) self.attn = SelfAttention(dim, num_heads, causal, attn_dropout, proj_dropout) - self.norm2 = LayerNorm(dim, eps=norm_eps) + self.norm2 = LayerNormNPU(dim, norm_eps) if torch.npu.is_available() else LayerNorm(dim, eps=norm_eps) if activation == 'swi_glu': self.mlp = SwiGLU(dim, int(dim * mlp_ratio)) else: @@ -354,7 +369,7 @@ class AttentionPool(nn.Module): self.to_q = nn.Linear(dim, dim) self.to_kv = nn.Linear(dim, dim * 2) self.proj = nn.Linear(dim, dim) - self.norm = LayerNorm(dim, eps=norm_eps) + self.norm = LayerNormNPU(dim, norm_eps) if torch.npu.is_available() else LayerNorm(dim, eps=norm_eps) self.mlp = nn.Sequential( nn.Linear(dim, int(dim * mlp_ratio)), QuickGELU() if activation == 'quick_gelu' else nn.GELU(), @@ -436,13 +451,16 @@ class VisionTransformer(nn.Module): self.dropout = nn.Dropout(embedding_dropout) # transformer - self.pre_norm = LayerNorm(dim, eps=norm_eps) if pre_norm else None + if pre_norm: + self.pre_norm = LayerNormNPU(dim, norm_eps) if torch.npu.is_available() else LayerNorm(dim, eps=norm_eps) + else: + self.pre_norm = None self.transformer = nn.Sequential(*[ AttentionBlock(dim, mlp_ratio, num_heads, 
post_norm, False, activation, attn_dropout, proj_dropout, norm_eps) for _ in range(num_layers) ]) - self.post_norm = LayerNorm(dim, eps=norm_eps) + self.post_norm = LayerNormNPU(dim, norm_eps) if torch.npu.is_available() else LayerNorm(dim, eps=norm_eps) # head if pool_type == 'token': diff --git a/diffsynth/models/wan_video_vae.py b/diffsynth/models/wan_video_vae.py index 397a2e7b66258159d84ac299b7798c52bc7e038a..78fae9fa1accd4f78b505349037494458ceec94f 100644 --- a/diffsynth/models/wan_video_vae.py +++ b/diffsynth/models/wan_video_vae.py @@ -6,7 +6,11 @@ import torch.nn.functional as F from tqdm import tqdm CACHE_T = 2 +ENABLE_VAE_PATCH_PARALLEL = False +def enable_vae_patch_parallel(): + global ENABLE_VAE_PATCH_PARALLEL + ENABLE_VAE_PATCH_PARALLEL = True def check_is_instance(model, module_class): if isinstance(model, module_class): @@ -37,9 +41,15 @@ class CausalConv3d(nn.Conv3d): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - self._padding = (self.padding[2], self.padding[2], self.padding[1], - self.padding[1], 2 * self.padding[0], 0) - self.padding = (0, 0, 0) + + if ENABLE_VAE_PATCH_PARALLEL: + self._padding = (0, 0, 0, + 0, 2 * self.padding[0], 0) + self.padding = (0, self.padding[1], self.padding[2]) + else: + self._padding = (self.padding[2], self.padding[2], self.padding[1], + self.padding[1], 2 * self.padding[0], 0) + self.padding = (0, 0, 0) def forward(self, x, cache_x=None): padding = list(self._padding) diff --git a/diffsynth/npu_utils/distributed/comm.py b/diffsynth/npu_utils/distributed/comm.py new file mode 100644 index 0000000000000000000000000000000000000000..287241e5b4de4edea35a214e723248c58233dcdf --- /dev/null +++ b/diffsynth/npu_utils/distributed/comm.py @@ -0,0 +1,95 @@ +import torch + +import torch.distributed as dist + + +def all_to_all_4D( + input_: torch.tensor, scatter_idx: int = 2, gather_idx: int = 1, group=None, use_sync: bool = False +) -> torch.tensor: + """ + all-to-all for QKV + + Args: + input_ (torch.tensor): a tensor sharded along dim scatter dim + scatter_idx (int): default 1 + gather_idx (int): default 2 + group : torch process group + use_sync (bool): whether to synchronize after all-to-all + + Returns: + torch.tensor: resharded tensor (bs, seqlen/P, hc, hs) + """ + assert ( + input_.dim() == 4 + ), f"input_ must be 4D tensor, got {input_.dim()} and shape {input_.shape}" + + seq_world_size = dist.get_world_size(group) + + if scatter_idx == 2 and gather_idx == 1: + # input_ (torch.tensor): a tensor sharded along dim 1 (bs, seqlen/P, hc, hs) output: (bs, seqlen, hc/P, hs) + bs, shard_seqlen, hc, hs = input_.shape + seqlen = shard_seqlen * seq_world_size + shard_hc = hc // seq_world_size + + # transpose groups of heads with the seq-len parallel dimension, so that we can scatter them! 
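        # A concrete (made-up) example of the resharding this branch performs:
        # with bs=1, seqlen=32, hc=8, hs=64 and P=4 ranks, each rank enters this
        # function holding (1, 8, 8, 64) and leaves holding (1, 32, 2, 64), i.e.
        # the full sequence but only hc/P = 2 of the heads.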
+ # (bs, seqlen/P, hc, hs) -reshape-> (bs, seq_len/P, P, hc/P, hs) -transpose(0,2)-> (P, seq_len/P, bs, hc/P, hs) + input_t = ( + input_.reshape(bs, shard_seqlen, seq_world_size, shard_hc, hs) + .transpose(0, 2) + .contiguous() + ) + + output = torch.empty_like(input_t) + # https://pytorch.org/docs/stable/distributed.html#torch.distributed.all_to_all_single + # (P, seq_len/P, bs, hc/P, hs) scatter seqlen -all2all-> (P, seq_len/P, bs, hc/P, hs) scatter head + + if seq_world_size > 1: + dist.all_to_all_single(output, input_t, group=group) + if use_sync: + torch.npu.synchronize() + else: + output = input_t + # if scattering the seq-dim, transpose the heads back to the original dimension + output = output.reshape(seqlen, bs, shard_hc, hs) + + # (seq_len, bs, hc/P, hs) -reshape-> (bs, seq_len, hc/P, hs) + output = output.transpose(0, 1).contiguous().reshape(bs, seqlen, shard_hc, hs) + + return output + + elif scatter_idx == 1 and gather_idx == 2: + # input_ (torch.tensor): a tensor sharded along dim 1 (bs, seqlen, hc/P, hs) output: (bs, seqlen/P, hc, hs) + bs, seqlen, shard_hc, hs = input_.shape + hc = shard_hc * seq_world_size + shard_seqlen = seqlen // seq_world_size + seq_world_size = dist.get_world_size(group) + + # transpose groups of heads with the seq-len parallel dimension, so that we can scatter them! + # (bs, seqlen, hc/P, hs) -reshape-> (bs, P, seq_len/P, hc/P, hs) -transpose(0, 3)-> (hc/P, P, seqlen/P, bs, hs) -transpose(0, 1) -> (P, hc/P, seqlen/P, bs, hs) + input_t = ( + input_.reshape(bs, seq_world_size, shard_seqlen, shard_hc, hs) + .transpose(0, 3) + .transpose(0, 1) + .contiguous() + .reshape(seq_world_size, shard_hc, shard_seqlen, bs, hs) + ) + + output = torch.empty_like(input_t) + # https://pytorch.org/docs/stable/distributed.html#torch.distributed.all_to_all_single + # (P, bs x hc/P, seqlen/P, hs) scatter seqlen -all2all-> (P, bs x seq_len/P, hc/P, hs) scatter head + if seq_world_size > 1: + dist.all_to_all_single(output, input_t, group=group) + if use_sync: + torch.npu.synchronize() + else: + output = input_t + + # if scattering the seq-dim, transpose the heads back to the original dimension + output = output.reshape(hc, shard_seqlen, bs, hs) + + # (hc, seqlen/N, bs, hs) -tranpose(0,2)-> (bs, seqlen/N, hc, hs) + output = output.transpose(0, 2).contiguous().reshape(bs, shard_seqlen, hc, hs) + + return output + else: + raise RuntimeError("scatter_idx must be 1 or 2 and gather_idx must be 1 or 2") diff --git a/diffsynth/npu_utils/distributed/fsdp.py b/diffsynth/npu_utils/distributed/fsdp.py new file mode 100644 index 0000000000000000000000000000000000000000..46fcd21624624257401f41c15e93a20998f0f5c7 --- /dev/null +++ b/diffsynth/npu_utils/distributed/fsdp.py @@ -0,0 +1,33 @@ +# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved. 
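# A minimal usage sketch (hypothetical call site, not part of this module):
#
#     import os
#     import torch.distributed as dist
#     from diffsynth.npu_utils.distributed.fsdp import shard_model
#
#     dist.init_process_group(backend="hccl")          # e.g. launched via torchrun
#     local_rank = int(os.getenv("LOCAL_RANK", "0"))   # torchrun sets LOCAL_RANK
#     dit = shard_model(dit, device_id=local_rank)     # wraps each entry of dit.blocks in FSDP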
+from functools import partial + +import torch +from torch.distributed.fsdp import FullyShardedDataParallel as FSDP +from torch.distributed.fsdp import MixedPrecision, ShardingStrategy +from torch.distributed.fsdp.wrap import lambda_auto_wrap_policy + + +def shard_model( + model, + device_id, + # param_dtype=torch.float32, + param_dtype=torch.bfloat16, + reduce_dtype=torch.float32, + buffer_dtype=torch.float32, + process_group=None, + sharding_strategy=ShardingStrategy.FULL_SHARD, + sync_module_states=True, +): + model = FSDP( + module=model, + process_group=process_group, + sharding_strategy=sharding_strategy, + auto_wrap_policy=partial( + lambda_auto_wrap_policy, lambda_fn=lambda m: m in model.blocks), + mixed_precision=MixedPrecision( + param_dtype=param_dtype, + reduce_dtype=reduce_dtype, + buffer_dtype=buffer_dtype), + device_id=device_id, + sync_module_states=sync_module_states) + return model diff --git a/diffsynth/npu_utils/distributed/group_coordinator.py b/diffsynth/npu_utils/distributed/group_coordinator.py new file mode 100644 index 0000000000000000000000000000000000000000..796bfe1a20ad0cb82aaf706ad626507c66a070f2 --- /dev/null +++ b/diffsynth/npu_utils/distributed/group_coordinator.py @@ -0,0 +1,597 @@ +# Copyright 2024 xDiT team. +# Adapted from +# https://github.com/vllm-project/vllm/blob/main/vllm/distributed/parallel_state.py +# Copyright 2023 The vLLM team. +# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. +from collections import namedtuple +from typing import Any, Dict, List, Optional, Tuple, Union +import pickle + +import torch +import torch_npu +import torch.distributed +from torch.distributed import Backend, ProcessGroup + +import logging + + +TensorMetadata = namedtuple("TensorMetadata", ["device", "dtype", "size"]) + + +def _split_tensor_dict( + tensor_dict: Dict[str, Union[torch.Tensor, Any]], prefix: str = "" +) -> Tuple[List[Tuple[str, Any]], List[torch.Tensor]]: + """Split the tensor dictionary into two parts: + 1. A list of (key, value) pairs. If the value is a tensor, it is replaced + by its metadata. + 2. A list of tensors. + + If the Tensor is nested under `tensor_dict["key1"]["key2"]`, the key of its + metadata will be "key1%key2". + """ + metadata_list: List[Tuple[str, Any]] = [] + tensor_list = [] + for key, value in tensor_dict.items(): + if "%" in key: + logging.error( + "Avoid having '%' in key " + "as it is used as a separator for nested entries." + ) + if isinstance(value, torch.Tensor): + # Note: we cannot use `value.device` here, + # because it contains not only the device type but also the device + # index (e.g. "npu:0"). We only need the device type. + # receiving side will set the device index. 
+ device = value.device.type + metadata_list.append( + (prefix + key, TensorMetadata(device, value.dtype, value.size())) + ) + tensor_list.append(value) + elif isinstance(value, dict): + if len(value) == 0: + metadata_list.append((prefix + key, value)) + inner_metadata_list, inner_tensor_list = _split_tensor_dict( + value, prefix + key + "%" + ) + metadata_list.extend(inner_metadata_list) + tensor_list.extend(inner_tensor_list) + else: + metadata_list.append((prefix + key, value)) + return metadata_list, tensor_list + + +def _update_nested_dict(nested_dict, flattened_key, value): + key_splits = flattened_key.split("%") + cur_dict = nested_dict + for k in key_splits[:-1]: + if k not in cur_dict: + cur_dict[k] = {} + cur_dict = cur_dict[k] + cur_dict[key_splits[-1]] = value + + +class GroupCoordinator: + """ + PyTorch ProcessGroup wrapper for a group of processes. + PyTorch ProcessGroup is bound to one specific communication backend, + e.g. NCCL, Gloo, MPI, etc. + GroupCoordinator takes charge of all the communication operations among + the processes in the group. It can route the communication to + a specific implementation (e.g. switch allreduce implementation + based on the tensor size and npu graph mode). + """ + + # available attributes: + rank: int # global rank + ranks: List[int] # global ranks in the group + world_size: int # size of the group + # difference between `local_rank` and `rank_in_group`: + # if we have a group of size 4 across two nodes: + # Process | Node | Rank | Local Rank | Rank in Group + # 0 | 0 | 0 | 0 | 0 + # 1 | 0 | 1 | 1 | 1 + # 2 | 1 | 2 | 0 | 2 + # 3 | 1 | 3 | 1 | 3 + local_rank: int # local rank used to assign devices + rank_in_group: int # rank inside the group + cpu_group: ProcessGroup # group for CPU communication + device_group: ProcessGroup # group for device communication + + def __init__( + self, + group_ranks: List[List[int]], + local_rank: int, + torch_distributed_backend: Union[str, Backend], + ): + + self.rank = torch.distributed.get_rank() + self.local_rank = local_rank + self.device_group = None + self.cpu_group = None + + for ranks in group_ranks: + device_group = torch.distributed.new_group( + ranks, backend=torch_distributed_backend + ) + # a group with `gloo` backend, to allow direct coordination between + # processes through the CPU. 
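            # (this gloo group backs broadcast_object / send_object / barrier below,
            #  so Python-object exchange and barriers stay on the CPU)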
+ cpu_group = torch.distributed.new_group(ranks, backend="gloo") + if self.rank in ranks: + self.ranks = ranks + self.world_size = len(ranks) + self.rank_in_group = ranks.index(self.rank) + self.device_group = device_group + self.cpu_group = cpu_group + + if torch.npu.is_available(): + self.device = torch.device(f"npu:{local_rank}") + else: + self.device = torch.device("cpu") + + @property + def first_rank(self): + """Return the global rank of the first process in the group""" + return self.ranks[0] + + @property + def last_rank(self): + """Return the global rank of the last process in the group""" + return self.ranks[-1] + + @property + def is_first_rank(self): + """Return whether the caller is the first process in the group""" + return self.rank == self.first_rank + + @property + def is_last_rank(self): + """Return whether the caller is the last process in the group""" + return self.rank == self.last_rank + + @property + def next_rank(self): + """Return the global rank of the process that follows the caller""" + rank_in_group = self.rank_in_group + world_size = self.world_size + return self.ranks[(rank_in_group + 1) % world_size] + + @property + def prev_rank(self): + """Return the global rank of the process that precedes the caller""" + rank_in_group = self.rank_in_group + world_size = self.world_size + return self.ranks[(rank_in_group - 1) % world_size] + + @property + def group_next_rank(self): + """Return the group rank of the process that follows the caller""" + rank_in_group = self.rank_in_group + world_size = self.world_size + return (rank_in_group + 1) % world_size + + @property + def group_prev_rank(self): + """Return the group rank of the process that precedes the caller""" + rank_in_group = self.rank_in_group + world_size = self.world_size + return (rank_in_group - 1) % world_size + + @property + def skip_rank(self): + """Return the global rank of the process that skip connects with the caller""" + rank_in_group = self.rank_in_group + world_size = self.world_size + return self.ranks[(world_size - rank_in_group - 1) % world_size] + + @property + def group_skip_rank(self): + """Return the group rank of the process that skip connects with the caller""" + rank_in_group = self.rank_in_group + world_size = self.world_size + return (world_size - rank_in_group - 1) % world_size + + def all_reduce(self, input_: torch.Tensor) -> torch.Tensor: + """ + NOTE: This operation will be applied in-place or out-of-place. + Always assume this function modifies its input, but use the return + value as the output. + """ + # Bypass the function if we are using only 1 GPU. + if self.world_size == 1: + return input_ + else: + torch.distributed.all_reduce(input_, group=self.device_group) + return input_ + + def all_gather( + self, input_: torch.Tensor, dim: int = 0, separate_tensors: bool = False + ) -> Union[torch.Tensor, List[torch.Tensor]]: + world_size = self.world_size + # Bypass the function if we are using only 1 GPU. + if world_size == 1: + return input_ + if dim < 0: + # Convert negative dim to positive. + dim += input_.dim() + # Allocate output tensor. + input_size = list(input_.size()) + input_size[0] *= world_size + output_tensor = torch.empty( + input_size, dtype=input_.dtype, device=input_.device + ) + # All-gather. 
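            # (gathered along dim 0 into one buffer; the reshape/movedim below
            #  restores the caller-requested `dim` when it is not 0)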
+ torch.distributed.all_gather_into_tensor( + output_tensor, input_, group=self.device_group + ) + if dim != 0: + input_size[0] //= world_size + output_tensor = output_tensor.reshape([world_size, ] + input_size) + output_tensor = output_tensor.movedim(0, dim) + + if separate_tensors: + tensor_list = [ + output_tensor.view(-1) + .narrow(0, input_.numel() * i, input_.numel()) + .view_as(input_) + for i in range(world_size) + ] + return tensor_list + else: + input_size = list(input_.size()) + input_size[dim] = input_size[dim] * world_size + # Reshape + output_tensor = output_tensor.reshape(input_size) + return output_tensor + + def gather(self, input_: torch.Tensor, dst: int = 0, dim: int = -1) -> torch.Tensor: + """ + NOTE: We assume that the input tensor is on the same device across + all the ranks. + NOTE: `dst` is the local rank of the destination rank. + """ + world_size = self.world_size + # Bypass the function if we are using only 1 GPU. + if world_size == 1: + return input_ + if dim < 0: + # Convert negative dim to positive. + dim += input_.dim() + # Allocate output tensor. + if self.rank_in_group == dst: + gather_list = [torch.empty_like(input_) for _ in range(world_size)] + else: + gather_list = None + # Gather. + torch.distributed.gather( + input_, gather_list, dst=self.ranks[dst], group=self.device_group + ) + if self.rank_in_group == dst: + output_tensor = torch.cat(gather_list, dim=dim) + else: + output_tensor = None + return output_tensor + + def broadcast(self, input_: torch.Tensor, src: int = 0): + """Broadcast the input tensor. + NOTE: `src` is the local rank of the source rank. + """ + + # Bypass the function if we are using only 1 GPU. + if self.world_size == 1: + return input_ + # Broadcast. + torch.distributed.broadcast( + input_, src=self.ranks[src], group=self.device_group + ) + return input_ + + def broadcast_object(self, obj: Optional[Any] = None, src: int = 0): + """Broadcast the input object. + NOTE: `src` is the local rank of the source rank. + """ + + # Bypass the function if we are using only 1 GPU. + if self.world_size == 1: + return obj + if self.shm_broadcaster is not None: + return self.shm_broadcaster.broadcast_object(obj) + if self.rank_in_group == src: + torch.distributed.broadcast_object_list( + [obj], src=self.ranks[src], group=self.cpu_group + ) + return obj + else: + recv = [None] + torch.distributed.broadcast_object_list( + recv, src=self.ranks[src], group=self.cpu_group + ) + return recv[0] + + def broadcast_object_list( + self, obj_list: List[Any], src: int = 0, group: Optional[ProcessGroup] = None + ): + """Broadcast the input object list. + NOTE: `src` is the local rank of the source rank. + """ + + # Bypass the function if we are using only 1 GPU. + if self.world_size == 1: + return obj_list + # Broadcast. 
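            # (unlike broadcast_object above, which goes through the gloo/CPU group,
            #  the object list here is broadcast over the device group)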
+ torch.distributed.broadcast_object_list( + obj_list, src=self.ranks[src], group=self.device_group + ) + return obj_list + + def send_object(self, obj: Any, dst: int) -> None: + """Send the input object list to the destination rank.""" + """NOTE: `dst` is the local rank of the destination rank.""" + + # Serialize object to tensor and get the size as well + object_tensor = torch.frombuffer(pickle.dumps(obj), dtype=torch.uint8) + + size_tensor = torch.tensor( + [object_tensor.numel()], dtype=torch.long, device="cpu" + ) + + # Send object size + + torch.distributed.send(size_tensor, dst=self.ranks[dst], group=self.cpu_group) + + # Send object + torch.distributed.send(object_tensor, dst=self.ranks[dst], group=self.cpu_group) + + return None + + def recv_object(self, src: int) -> Any: + """Receive the input object list from the source rank.""" + """NOTE: `src` is the local rank of the source rank.""" + + size_tensor = torch.empty(1, dtype=torch.long, device="cpu") + + # Receive object size + rank_size = torch.distributed.recv( + size_tensor, src=self.ranks[src], group=self.cpu_group + ) + + # Tensor to receive serialized objects into. + object_tensor = torch.empty( # type: ignore[call-overload] + size_tensor.item(), # type: ignore[arg-type] + dtype=torch.uint8, + device="cpu", + ) + + rank_object = torch.distributed.recv( + object_tensor, src=self.ranks[src], group=self.cpu_group + ) + + obj = pickle.loads(object_tensor.numpy().tobytes()) + + return obj + + def broadcast_tensor_dict( + self, + tensor_dict: Optional[Dict[str, Union[torch.Tensor, Any]]] = None, + src: int = 0, + group: Optional[ProcessGroup] = None, + metadata_group: Optional[ProcessGroup] = None, + ) -> Optional[Dict[str, Union[torch.Tensor, Any]]]: + """Broadcast the input tensor dictionary. + NOTE: `src` is the local rank of the source rank. + """ + # Bypass the function if we are using only 1 GPU. + if not torch.distributed.is_initialized() or self.world_size == 1: + return tensor_dict + + group = self.device_group + metadata_group = self.cpu_group + src = self.ranks[src] + + rank = self.rank + if rank == src: + metadata_list: List[Tuple[Any, Any]] = [] + metadata_list, tensor_list = _split_tensor_dict(tensor_dict) + # `metadata_list` lives in CPU memory. + # `broadcast_object_list` has serialization & deserialization, + # all happening on CPU. Therefore, we can use the CPU group. + self.broadcast_object(metadata_list, src=src) + async_handles = [] + for tensor in tensor_list: + if tensor.numel() == 0: + # Skip broadcasting empty tensors. + continue + if tensor.is_cpu: + # use metadata_group for CPU tensors + handle = torch.distributed.broadcast( + tensor, src=src, group=metadata_group, async_op=True + ) + else: + # use group for GPU tensors + handle = torch.distributed.broadcast( + tensor, src=src, group=group, async_op=True + ) + async_handles.append(handle) + for async_handle in async_handles: + async_handle.wait() + + else: + metadata_list = self.broadcast_object(None, src=src) + tensor_dict = {} + async_handles = [] + for key, value in metadata_list: + if isinstance(value, TensorMetadata): + tensor = torch.empty( + value.size, dtype=value.dtype, device=value.device + ) + if tensor.numel() == 0: + # Skip broadcasting empty tensors. 
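                        # (the empty tensor is still inserted locally so the receiver's
                        #  dict keeps the same keys as the sender's)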
+ _update_nested_dict(tensor_dict, key, tensor) + continue + if tensor.is_cpu: + # use metadata_group for CPU tensors + handle = torch.distributed.broadcast( + tensor, src=src, group=metadata_group, async_op=True + ) + else: + # use group for GPU tensors + handle = torch.distributed.broadcast( + tensor, src=src, group=group, async_op=True + ) + async_handles.append(handle) + _update_nested_dict(tensor_dict, key, tensor) + else: + _update_nested_dict(tensor_dict, key, value) + for async_handle in async_handles: + async_handle.wait() + return tensor_dict + + def send_tensor_dict( + self, + tensor_dict: Dict[str, Union[torch.Tensor, Any]], + dst: Optional[int] = None, + ) -> Optional[Dict[str, Union[torch.Tensor, Any]]]: + """Send the input tensor dictionary. + NOTE: `dst` is the local rank of the source rank. + """ + # Bypass the function if we are using only 1 GPU. + if not torch.distributed.is_initialized() or self.world_size == 1: + return tensor_dict + + group = self.device_group + metadata_group = self.cpu_group + + if dst is None: + dst = self.group_next_rank + + metadata_list: List[Tuple[Any, Any]] = [] + metadata_list, tensor_list = _split_tensor_dict(tensor_dict) + # `metadata_list` lives in CPU memory. + # `send_object_list` has serialization & deserialization, + # all happening on CPU. Therefore, we can use the CPU group. + self.send_object(metadata_list, dst=dst) + for tensor in tensor_list: + if tensor.numel() == 0: + # Skip sending empty tensors. + continue + if tensor.is_cpu: + # use metadata_group for CPU tensors + torch.distributed.send( + tensor, dst=self.ranks[dst], group=metadata_group + ) + else: + # use group for GPU tensors + torch.distributed.send(tensor, dst=self.ranks[dst], group=group) + return None + + def recv_tensor_dict( + self, src: Optional[int] = None + ) -> Optional[Dict[str, Union[torch.Tensor, Any]]]: + """Recv the input tensor dictionary. + NOTE: `src` is the local rank of the source rank. + """ + # Bypass the function if we are using only 1 GPU. + if not torch.distributed.is_initialized() or self.world_size == 1: + return None + + group = self.device_group + metadata_group = self.cpu_group + + if src is None: + src = self.group_prev_rank + + recv_metadata_list = self.recv_object(src=src) + tensor_dict: Dict[str, Any] = {} + for key, value in recv_metadata_list: + if isinstance(value, TensorMetadata): + tensor = torch.empty(value.size, dtype=value.dtype, device=value.device) + if tensor.numel() == 0: + # Skip broadcasting empty tensors. + _update_nested_dict(tensor_dict, key, tensor) + continue + if tensor.is_cpu: + # use metadata_group for CPU tensors + torch.distributed.recv( + tensor, src=self.ranks[src], group=metadata_group + ) + else: + # use group for GPU tensors + torch.distributed.recv(tensor, src=self.ranks[src], group=group) + _update_nested_dict(tensor_dict, key, tensor) + else: + _update_nested_dict(tensor_dict, key, value) + return tensor_dict + + def barrier(self): + """Barrier synchronization among the group. + NOTE: don't use `device_group` here! `barrier` in NCCL is + terrible because it is internally a broadcast operation with + secretly created GPU tensors. It is easy to mess up the current + device. Use the CPU group instead. 
+ """ + torch.distributed.barrier(group=self.cpu_group) + + def send(self, tensor: torch.Tensor, dst: Optional[int] = None) -> None: + """Sends a tensor to the destination rank in a non-blocking way""" + """NOTE: `dst` is the rank_in_group of the destination rank.""" + if dst is None: + dst = self.group_next_rank + + torch.distributed.send( + tensor, + self.ranks[dst], + group=( + self.device_groups[self.rank_in_group % 2] + if self.world_size == 2 + else self.device_group + ), + ) + + def recv( + self, size: torch.Size, dtype: torch.dtype, src: Optional[int] = None + ) -> torch.Tensor: + """Receives a tensor from the src rank.""" + """NOTE: `src` is the rank_in_group of the source rank.""" + if src is None: + src = self.group_prev_rank + + tensor = torch.empty(size, dtype=dtype, device=self.device) + torch.distributed.recv( + tensor, + self.ranks[src], + ( + self.device_groups[(self.rank_in_group + 1) % 2] + if self.world_size == 2 + else self.device_group + ), + ) + return tensor + + def destroy(self): + if self.device_group is not None: + torch.distributed.destroy_process_group(self.device_group) + self.device_group = None + if self.cpu_group is not None: + torch.distributed.destroy_process_group(self.cpu_group) + self.cpu_group = None + + +class SequenceParallelGroupCoordinator(GroupCoordinator): + def __init__( + self, + group_ranks: List[List[int]], + local_rank: int, + torch_distributed_backend: Union[str, Backend], + **kwargs, + ): + super().__init__( + group_ranks=group_ranks, + local_rank=local_rank, + torch_distributed_backend=torch_distributed_backend, + ) + self.ulysses_group = kwargs.get("ulysses_group", None) + self.ulysses_world_size = torch.distributed.get_world_size(self.ulysses_group) + self.ulysses_rank = torch.distributed.get_rank(self.ulysses_group) + + self.ring_group = kwargs.get("ring_group", None) + self.ring_world_size = torch.distributed.get_world_size(self.ring_group) + self.ring_rank = torch.distributed.get_rank(self.ring_group) \ No newline at end of file diff --git a/diffsynth/npu_utils/distributed/parallel_mgr.py b/diffsynth/npu_utils/distributed/parallel_mgr.py new file mode 100644 index 0000000000000000000000000000000000000000..01ff65b466d39c41c4fc9624c920ea9cd7db0481 --- /dev/null +++ b/diffsynth/npu_utils/distributed/parallel_mgr.py @@ -0,0 +1,342 @@ +import os +from typing import List, Optional +from dataclasses import dataclass +import torch.distributed as dist +import torch_npu +import logging +from .utils import RankGenerator, generate_masked_orthogonal_rank_groups +from .group_coordinator import GroupCoordinator, SequenceParallelGroupCoordinator + +from yunchang import set_seq_parallel_pg +from yunchang.globals import PROCESS_GROUP + +_WORLD: Optional[GroupCoordinator] = None +_TP: Optional[GroupCoordinator] = None +_SP: Optional[SequenceParallelGroupCoordinator] = None +_CFG: Optional[GroupCoordinator] = None + + +@dataclass +class ParallelConfig: + tp_degree: int = 1 + sp_degree: int = 1 + ulysses_degree: int = 1 + ring_degree: int = 1 + use_cfg_parallel: bool = False + world_size: int = 1 + + def __post_init__(self): + if self.use_cfg_parallel: + self.cfg_degree = 2 + else: + self.cfg_degree = 1 + if not self.tp_degree * self.sp_degree * self.cfg_degree <= self.world_size: + logging.error( + "tp_degree * sp_degree * cfg_degree must be less than or equal to " + "world_size because of classifier free guidance" + ) + if not (self.world_size % (self.tp_degree * self.sp_degree * self.cfg_degree) == 0): + logging.error("world_size must be divisible by 
tp_degree * sp_degree * cfg_degree") + + +# * QUERY +def get_world_group() -> GroupCoordinator: + if _WORLD is None: + logging.error("world group is not initialized") + return _WORLD + + +# TP +def get_tp_group() -> GroupCoordinator: + assert _TP is not None, "tensor model parallel group is not initialized" + return _TP + + +def get_tensor_model_parallel_world_size(): + """Return world size for the tensor model parallel group.""" + return get_tp_group().world_size + + +def get_tensor_model_parallel_rank(): + """Return my rank for the tensor model parallel group.""" + return get_tp_group().rank_in_group + + +# SP +def get_sp_group() -> SequenceParallelGroupCoordinator: + if _SP is None: + logging.error("pipeline model parallel group is not initialized") + return _SP + + +def get_sequence_parallel_state(): + """Return state for the sequence parallel group.""" + return _SP is not None + + +def get_sequence_parallel_world_size(): + """Return world size for the sequence parallel group.""" + if not get_sequence_parallel_state(): + return 1 + return get_sp_group().world_size + + +def get_sequence_parallel_rank(): + """Return my rank for the sequence parallel group.""" + if not get_sequence_parallel_state(): + return 0 + return get_sp_group().rank_in_group + + +# CFG +def get_cfg_group() -> GroupCoordinator: + if _CFG is None: + logging.error("classifier_free_guidance parallel group is not initialized") + return _CFG + + +def get_cfg_state(): + """Return state for the sequence parallel group.""" + return _CFG is not None + + +def get_classifier_free_guidance_world_size(): + """Return world size for the classifier_free_guidance parallel group.""" + if not get_cfg_state(): + return 1 + return get_cfg_group().world_size + + +def get_classifier_free_guidance_rank(): + """Return my rank for the classifier_free_guidance parallel group.""" + if not get_cfg_state(): + return 0 + return get_cfg_group().rank_in_group + + +def init_world_group( + ranks: List[int], local_rank: int, backend: str +) -> GroupCoordinator: + return GroupCoordinator( + group_ranks=[ranks], + local_rank=local_rank, + torch_distributed_backend=backend, + ) + + +def init_distributed_environment( + world_size: int = -1, + rank: int = -1, + distributed_init_method: str = "env://", + local_rank: int = -1, + backend: str = "hccl", +): + logging.debug( + "world_size=%d rank=%d local_rank=%d " "distributed_init_method=%s backend=%s", + world_size, + rank, + local_rank, + distributed_init_method, + backend, + ) + if not dist.is_initialized(): + if distributed_init_method is None: + logging.error( + "distributed_init_method must be provided when initializing " + "distributed environment" + ) + # this backend is used for WORLD + dist.init_process_group( + backend=backend, + init_method=distributed_init_method, + world_size=world_size, + rank=rank, + ) + # set the local rank + # local_rank is not available in torch ProcessGroup, + # see https://github.com/pytorch/pytorch/issues/122816 + if local_rank == -1: + # local rank not set, this usually happens in single-node + # setting, where we can use rank as local rank + if distributed_init_method == "env://": + local_rank = int(os.getenv('LOCAL_RANK', 0)) + torch_npu.npu.set_device(local_rank) + else: + local_rank = rank + global _WORLD + if _WORLD is None: + ranks = list(range(dist.get_world_size())) + _WORLD = init_world_group(ranks, local_rank, backend) + else: + if not _WORLD.world_size == dist.get_world_size(): + logging.error("world group already initialized with a different world size") + + 
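# A minimal end-to-end sketch of how this module is typically driven. The torchrun
# launcher and the degrees chosen below are assumptions; only ParallelConfig,
# init_parallel_env and the query helpers are defined in this file:
#
#     import torch.distributed as dist
#     dist.init_process_group(backend="hccl")             # e.g. under torchrun
#     cfg = ParallelConfig(sp_degree=dist.get_world_size(),
#                          ulysses_degree=dist.get_world_size(),
#                          world_size=dist.get_world_size())
#     init_parallel_env(cfg)                              # builds the _CFG / _SP / _TP groups
#     sp_rank = get_sequence_parallel_rank()
#     sp_size = get_sequence_parallel_world_size()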
+def model_parallel_is_initialized(): + """Check if tensor and pipeline parallel groups are initialized.""" + return ( + _CFG is not None + and _SP is not None + and _TP is not None + ) + + +def init_model_parallel_group( + group_ranks: List[List[int]], + local_rank: int, + backend: str, + parallel_mode: str, + **kwargs, +) -> GroupCoordinator: + if parallel_mode not in [ + "tensor", + "sequence", + "classifier_free_guidance", + ]: + logging.error(f"parallel_mode {parallel_mode} is not supported") + if parallel_mode == "sequence": + return SequenceParallelGroupCoordinator( + group_ranks=group_ranks, + local_rank=local_rank, + torch_distributed_backend=backend, + **kwargs, + ) + else: + return GroupCoordinator( + group_ranks=group_ranks, + local_rank=local_rank, + torch_distributed_backend=backend, + ) + + +def initialize_model_parallel( + classifier_free_guidance_degree: int = 1, + sequence_parallel_degree: int = 1, + ulysses_degree: int = 1, + ring_degree: int = 1, + tensor_parallel_degree: int = 1, + backend: Optional[str] = None, +) -> None: + """ + Initialize model parallel groups. + + Arguments: + classifier_free_guidance_degree: number of GPUs used for Classifier Free Guidance (CFG) + sequence_parallel_degree: number of GPUs used for sequence parallelism. + tensor_parallel_degree: number of GPUs used for tensor parallelism. + backend: distributed backend of pytorch collective comm. + """ + # Get world size and rank. Ensure some consistencies. + if not dist.is_initialized(): + logging.error("dist is not initialized") + world_size: int = dist.get_world_size() + backend = backend + + if ( + world_size + != classifier_free_guidance_degree + * sequence_parallel_degree + * tensor_parallel_degree + ): + raise RuntimeError( + f"world_size ({world_size}) is not equal to " + f"sequence_parallel_degree ({sequence_parallel_degree}) x " + f"classifier_free_guidance_degree " + f"({classifier_free_guidance_degree}) x " + f"tensor_parallel_degree " + f"({tensor_parallel_degree})" + ) + + rank_generator: RankGenerator = RankGenerator( + tensor_parallel_degree, + sequence_parallel_degree, + classifier_free_guidance_degree, + "tp-sp-cfg", + ) + + global _CFG + if _CFG is not None: + logging.error("classifier_free_guidance group is already initialized") + _CFG = init_model_parallel_group( + group_ranks=rank_generator.get_ranks("cfg"), + local_rank=get_world_group().local_rank, + backend=backend, + parallel_mode="classifier_free_guidance", + ) + + global _SP + if _SP is not None: + logging.error("sequence parallel group is already initialized") + set_seq_parallel_pg( + sp_ulysses_degree=ulysses_degree, + sp_ring_degree=ring_degree, + rank=get_world_group().rank_in_group, + world_size=world_size + ) + _SP = init_model_parallel_group( + group_ranks=rank_generator.get_ranks("sp"), + local_rank=get_world_group().local_rank, + backend=backend, + parallel_mode="sequence", + ulysses_group=PROCESS_GROUP.ULYSSES_PG, + ring_group=PROCESS_GROUP.RING_PG, + ) + + global _TP + assert _TP is None, "Tensor parallel group is already initialized" + _TP = init_model_parallel_group( + group_ranks=rank_generator.get_ranks("tp"), + local_rank=get_world_group().local_rank, + backend=backend, + parallel_mode="tensor", + ) + + +def destroy_model_parallel(): + """Set the groups to none and destroy them.""" + global _CFG + if _CFG: + _CFG.destroy() + _CFG = None + + global _SP + if _SP: + _SP.destroy() + _SP = None + + global _TP + if _TP: + _TP.destroy() + _TP = None + + +def destroy_distributed_environment(): + global _WORLD + 
if _WORLD: + _WORLD.destroy() + _WORLD = None + if dist.is_initialized(): + dist.destroy_process_group() + + +def init_parallel_env(parallel_config: ParallelConfig): + if not model_parallel_is_initialized(): + logging.warning("Model parallel is not initialized, initializing...") + init_distributed_environment( + world_size=dist.get_world_size(), + rank=dist.get_rank(), + backend='hccl', + ) + initialize_model_parallel( + classifier_free_guidance_degree=parallel_config.cfg_degree, + sequence_parallel_degree=parallel_config.sp_degree, + ulysses_degree=parallel_config.ulysses_degree, + ring_degree=parallel_config.ring_degree, + tensor_parallel_degree=parallel_config.tp_degree, + ) + + +def finalize_parallel_env(): + if model_parallel_is_initialized(): + destroy_model_parallel() + destroy_distributed_environment() \ No newline at end of file diff --git a/diffsynth/npu_utils/distributed/tp_applicator.py b/diffsynth/npu_utils/distributed/tp_applicator.py new file mode 100644 index 0000000000000000000000000000000000000000..eb65c823f096fda7892db8d8455ac0c8235133be --- /dev/null +++ b/diffsynth/npu_utils/distributed/tp_applicator.py @@ -0,0 +1,329 @@ +import torch +import torch.nn as nn +import torch_npu + +from ..modules.model import WanSelfAttention, WanAttentionBlock, WanRMSNorm +from .parallel_mgr import ( + get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size, + get_tp_group, +) +from .group_coordinator import GroupCoordinator + + +class TensorParallelApplicator: + def __init__(self, tp_size, device_map="cpu", tp_group=None): + self.tp_size = tp_size + self.tp_rank = get_tensor_model_parallel_rank() + self.tp_group = tp_group or get_tp_group() + self.device_map = device_map + + def apply_to_model(self, model): + self._apply_tp_to_attention(model) + self._apply_tp_to_ffn(model) + + def _apply_tp_to_attention(self, module): + for name, child in module.named_children(): + if isinstance(child, WanSelfAttention): + self._replace_self_attention(child) + else: + self._apply_tp_to_attention(child) + + def _replace_self_attention(self, child): + child.dim = child.dim // self.tp_size + child.num_heads = child.num_heads // self.tp_size + orig_q = child.q + orig_k = child.k + orig_v = child.v + orig_o = child.o + orig_dtype = orig_q.weight.dtype + + column_out = orig_q.out_features // self.tp_size + row_in = orig_o.in_features // self.tp_size + + child.q = ColumnParallelLinear( + orig_q.in_features, + column_out, + bias=orig_q.bias is not None, + gather_output=False, + tp_size=self.tp_size, + tp_rank=self.tp_rank, + tp_group=self.tp_group + ).to(dtype=orig_dtype).to(self.device_map) + + child.k = ColumnParallelLinear( + orig_k.in_features, + column_out, + bias=orig_k.bias is not None, + gather_output=False, + tp_size=self.tp_size, + tp_rank=self.tp_rank, + tp_group=self.tp_group + ).to(dtype=orig_dtype).to(self.device_map) + + child.v = ColumnParallelLinear( + orig_v.in_features, + column_out, + bias=orig_v.bias is not None, + gather_output=False, + tp_size=self.tp_size, + tp_rank=self.tp_rank, + tp_group=self.tp_group + ).to(dtype=orig_dtype).to(self.device_map) + + child.o = RowParallelLinear( + row_in, + orig_o.out_features, + bias=orig_o.bias is not None, + input_is_parallel=True, + tp_size=self.tp_size, + tp_rank=self.tp_rank, + tp_group=self.tp_group + ).to(dtype=orig_dtype).to(self.device_map) + + self._split_self_weights(child, orig_q, orig_k, orig_v, orig_o) + + if isinstance(child.norm_q, WanRMSNorm): + ori_norm_q = child.norm_q + child.norm_q = TensorParallelRMSNorm( + 
dim=child.norm_q.dim, + tp_size=self.tp_size, + tp_group=self.tp_group + ) + self._split_norm_weights(child.norm_q, ori_norm_q) + + if isinstance(child.norm_k, WanRMSNorm): + ori_norm_k = child.norm_k + child.norm_k = TensorParallelRMSNorm( + dim=child.norm_k.dim, + tp_size=self.tp_size, + tp_group=self.tp_group + ) + self._split_norm_weights(child.norm_k, ori_norm_k) + + + def _split_self_weights(self, new_layer, orig_q, orig_k, orig_v, orig_o): + q_chunk = torch.chunk(orig_q.weight.data, self.tp_size, dim=0)[self.tp_rank] + new_layer.q.weight.data = q_chunk.contiguous() + + k_chunk = torch.chunk(orig_k.weight.data, self.tp_size, dim=0)[self.tp_rank] + new_layer.k.weight.data = k_chunk.contiguous() + + v_chunk = torch.chunk(orig_v.weight.data, self.tp_size, dim=0)[self.tp_rank] + new_layer.v.weight.data = v_chunk.contiguous() + + o_chunk = torch.chunk(orig_o.weight.data, self.tp_size, dim=1)[self.tp_rank] + new_layer.o.weight.data = o_chunk.contiguous() + + if orig_q.bias is not None: + bias_chunk = torch.chunk(orig_q.bias.data, self.tp_size, dim=0)[self.tp_rank] + new_layer.q.bias.data = bias_chunk.contiguous() + if orig_k.bias is not None: + bias_chunk = torch.chunk(orig_k.bias.data, self.tp_size, dim=0)[self.tp_rank] + new_layer.k.bias.data = bias_chunk.contiguous() + if orig_v.bias is not None: + bias_chunk = torch.chunk(orig_v.bias.data, self.tp_size, dim=0)[self.tp_rank] + new_layer.v.bias.data = bias_chunk.contiguous() + if orig_o.bias is not None: + new_layer.o.bias.data = orig_o.bias.data.clone() / self.tp_size + + def _split_norm_weights(self, new_layer, norm): + norm_chunk = torch.chunk(norm.weight.data, self.tp_size, dim=0)[self.tp_rank] + new_layer.weight.data = norm_chunk.contiguous() + + def _replace_cross_attention(self, child): + orig_wq = child.wq + orig_wkv = child.wkv + orig_wo = child.wo + orig_dtype = orig_wq.weight.dtype + + column_out_wq = orig_wq.out_features // self.tp_size + column_out_wkv = orig_wkv.out_features // self.tp_size + row_in_wo = orig_wo.in_features // self.tp_size + + child.wq = ColumnParallelLinear( + orig_wq.in_features, + column_out_wq, + bias=orig_wq.bias is not None, + gather_output=False, + tp_size=self.tp_size, + tp_rank=self.tp_rank, + tp_group=self.tp_group + ).to(dtype=orig_dtype).to(self.device_map) + + child.wkv = ColumnParallelLinear( + orig_wkv.in_features, + column_out_wkv, + bias=orig_wkv.bias is not None, + gather_output=False, + tp_size=self.tp_size, + tp_rank=self.tp_rank, + tp_group=self.tp_group + ).to(dtype=orig_dtype).to(self.device_map) + + child.wo = RowParallelLinear( + row_in_wo, + orig_wo.out_features, + bias=orig_wo.bias is not None, + input_is_parallel=True, + tp_size=self.tp_size, + tp_rank=self.tp_rank, + tp_group=self.tp_group + ).to(dtype=orig_dtype).to(self.device_map) + + self._split_cross_attention_weights(child, orig_wq, orig_wkv, orig_wo) + child.n_heads_per_tp = child.n_heads // self.tp_size + + def _split_cross_attention_weights(self, new_layer, orig_wq, orig_wkv, orig_wo): + wq_chunk = torch.chunk(orig_wq.weight.data, self.tp_size, dim=0)[self.tp_rank] + new_layer.wq.weight.data = wq_chunk.contiguous() + if orig_wq.bias is not None: + wq_bias_chunk = torch.chunk(orig_wq.bias.data, self.tp_size, dim=0)[self.tp_rank] + new_layer.wq.bias.data = wq_bias_chunk.contiguous() + + wkv_chunk = torch.chunk(orig_wkv.weight.data, self.tp_size, dim=0)[self.tp_rank] + new_layer.wkv.weight.data = wkv_chunk.contiguous() + if orig_wkv.bias is not None: + wkv_bias_chunk = torch.chunk(orig_wkv.bias.data, self.tp_size, 
dim=0)[self.tp_rank] + new_layer.wkv.bias.data = wkv_bias_chunk.contiguous() + + wo_chunk = torch.chunk(orig_wo.weight.data, self.tp_size, dim=1)[self.tp_rank] + new_layer.wo.weight.data = wo_chunk.contiguous() + if orig_wo.bias is not None: + new_layer.wo.bias.data = orig_wo.bias.data.clone() / self.tp_size + + def _apply_tp_to_ffn(self, module): + for name, child in module.named_children(): + if isinstance(child, WanAttentionBlock): + self._replace_ffn_layers(child) + else: + self._apply_tp_to_ffn(child) + + def _replace_ffn_layers(self, block): + ff_layer = block.ffn + orig_gelu_linear = ff_layer[0] + inner_dim_per_tp = orig_gelu_linear.out_features // self.tp_size + orig_dtype = orig_gelu_linear.weight.dtype + + ff_layer[0] = ColumnParallelLinear( + in_features=orig_gelu_linear.in_features, + out_features=inner_dim_per_tp, + bias=orig_gelu_linear.bias is not None, + gather_output=False, + tp_size=self.tp_size, + tp_rank=self.tp_rank, + tp_group=self.tp_group + ).to(dtype=orig_dtype).to(self.device_map) + + orig_output_linear = ff_layer[2] + ff_layer[2] = RowParallelLinear( + in_features=inner_dim_per_tp, + out_features=orig_output_linear.out_features, + bias=orig_output_linear.bias is not None, + input_is_parallel=True, + tp_size=self.tp_size, + tp_rank=self.tp_rank, + tp_group=self.tp_group + ).to(dtype=orig_dtype).to(self.device_map) + + self._split_ffn_weights(ff_layer, orig_gelu_linear, orig_output_linear) + + def _split_ffn_weights(self, new_ffn, orig_first_linear, orig_second_linear): + with torch.no_grad(): + first_weight_chunk = torch.chunk(orig_first_linear.weight.data, self.tp_size, dim=0)[self.tp_rank] + new_ffn[0].weight.data.copy_(first_weight_chunk.contiguous()) + + if orig_first_linear.bias is not None: + first_bias_chunk = torch.chunk(orig_first_linear.bias.data, self.tp_size, dim=0)[self.tp_rank] + new_ffn[0].bias.data.copy_(first_bias_chunk.contiguous()) + + second_weight_chunk = torch.chunk(orig_second_linear.weight.data, self.tp_size, dim=1)[self.tp_rank] + new_ffn[2].weight.data.copy_(second_weight_chunk.contiguous()) + + if orig_second_linear.bias is not None: + new_ffn[2].bias.data.copy_(orig_second_linear.bias.data.clone() / self.tp_size) + + +class ColumnParallelLinear(nn.Linear): + def __init__(self, in_features, out_features, bias=True, gather_output=True, tp_size=None, tp_rank=None, tp_group=None): + self.tp_size = tp_size or get_tensor_model_parallel_world_size() + self.tp_rank = tp_rank or get_tensor_model_parallel_rank() + self.tp_group = tp_group or get_tp_group() + + super().__init__(in_features, out_features, bias=bias) + + def forward(self, x): + x = super().forward(x) + return x + + +class RowParallelLinear(nn.Linear): + def __init__(self, in_features, out_features, bias=True, input_is_parallel=True, + tp_size=None, tp_rank=None, tp_group=None, matmul_allreduce_type="torch"): + self.tp_size = tp_size or get_tensor_model_parallel_world_size() + self.tp_rank = tp_rank or get_tensor_model_parallel_rank() + self.tp_group = tp_group or get_tp_group() + self.input_is_parallel = input_is_parallel + + if matmul_allreduce_type == "atb": + try: + from atb_ops.ops.matmul_allreduce import matmul_allreduce + self.matmul_allreduce = matmul_allreduce + self.matmul_allreduce_type = "atb" + except Exception: + self.matmul_allreduce = None + self.matmul_allreduce_type = "torch" + else: + self.matmul_allreduce_type = matmul_allreduce_type + + super().__init__(in_features, out_features, bias=bias) + + def forward(self, x): + if not self.input_is_parallel: + x = 
torch.chunk(x, self.tp_size, dim=-1)[self.tp_rank] + + if self.matmul_allreduce_type == "atb": + if x.dim() == 2: + output = torch.empty((x.shape[0], self.weight.shape[0]), dtype=x.dtype, device=x.device) + elif x.dim() == 3: + b, s, hx = x.size() + output = torch.empty((b, s, self.weight.shape[0]), dtype=x.dtype, device=x.device) + self.matmul_allreduce(output, x, self.weight) + elif self.matmul_allreduce_type == "torch_npu": + if isinstance(self.tp_group, GroupCoordinator): + tp_pg = self.tp_group.device_group + else: + tp_pg = self.tp_group + hcom = tp_pg._get_backend(torch.device('npu')).get_hccl_comm_name + output = torch_npu.npu_mm_all_reduce_base(x, self.weight, hcom) + else: + x = super().forward(x) + # 执行All-Reduce聚合结果 + if isinstance(self.tp_group, GroupCoordinator): + output = self.tp_group.all_reduce(x) + else: + torch.distributed.all_reduce(x, group=self.tp_group) + output = x + return output + + +class TensorParallelRMSNorm(nn.Module): + def __init__(self, dim, tp_size, tp_group, eps=1e-6): + super().__init__() + self.tp_size = tp_size + self.tp_group = tp_group + self.variance_epsilon = eps + self.weight = nn.Parameter(torch.ones(dim // self.tp_size)) + + def forward(self, hidden_states): + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + if isinstance(self.tp_group, GroupCoordinator): + variance = self.tp_group.all_reduce(variance) + else: + torch.distributed.all_reduce(variance, group=self.tp_group) + variance /= self.tp_size + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + + return self.weight * hidden_states.to(input_dtype) \ No newline at end of file diff --git a/diffsynth/npu_utils/distributed/utils.py b/diffsynth/npu_utils/distributed/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..9dafb8de1e7f02baf636fa71d22d5a6117e3442a --- /dev/null +++ b/diffsynth/npu_utils/distributed/utils.py @@ -0,0 +1,152 @@ +from typing import List +import logging + + +def generate_masked_orthogonal_rank_groups( + world_size: int, parallel_size: List[int], mask: List[bool] +) -> List[List[int]]: + """Generate orthogonal parallel groups based on the parallel size and mask. + + Arguments: + world_size (int): world size + + parallel_size (List[int]): + The parallel size of each orthogonal parallel type. For example, if + tensor_parallel_size = 2, pipeline_model_parallel_group = 3, data_parallel_size = 4, + and the parallel mapping order is tp-pp-dp, then the parallel_size = [2, 3, 4]. + + mask (List[bool]): + The mask controls which parallel methods the generated groups represent. If mask[i] is + True, it means the generated group contains the i-th parallelism method. For example, + if parallel_size = [tp_size, pp_size, dp_size], and mask = [True, False , True], then + the generated group is the `tp-dp` group, if the mask = [False, True, False], then the + generated group is the `pp` group. + """ + + def prefix_product(a: List[int], init=1) -> List[int]: + r = [init] + for v in a: + init = init * v + r.append(init) + return r + + def inner_product(a: List[int], b: List[int]) -> int: + return sum([x * y for x, y in zip(a, b)]) + + def decompose(index, shape, stride=None): + """ + This function solve the math problem below: + There is an equation: + index = sum(idx[i] * stride[i]) + And given the value of index, stride. + Return the idx. + This function will used to get the pp/dp/pp_rank + from group_index and rank_in_group. 
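+            For example, with shape = [2, 3, 4] the default stride is
+            prefix_product(shape) = [1, 2, 6, 24] (only the first three entries
+            are used), and decompose(11, [2, 3, 4]) returns [1, 2, 1],
+            since 1*1 + 2*2 + 1*6 == 11.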
+ """ + if stride is None: + stride = prefix_product(shape) + idx = [(index // d) % s for s, d in zip(shape, stride)] + # stride is a prefix_product result. And the value of stride[-1] + # is not used. + if not ( + sum([x * y for x, y in zip(idx, stride[:-1])]) == index + ): + logging.error("idx {} with shape {} mismatch the return idx {}".format(index, shape, idx)) + return idx + + masked_shape = [s for s, m in zip(parallel_size, mask) if m] + unmasked_shape = [s for s, m in zip(parallel_size, mask) if not m] + + global_stride = prefix_product(parallel_size) + masked_stride = [d for d, m in zip(global_stride, mask) if m] + unmasked_stride = [d for d, m in zip(global_stride, mask) if not m] + + group_size = prefix_product(masked_shape)[-1] + num_of_group = world_size // group_size + + ranks = [] + for group_index in range(num_of_group): + # get indices from unmaksed for group_index. + decomposed_group_idx = decompose(group_index, unmasked_shape) + rank = [] + for rank_in_group in range(group_size): + # get indices from masked for rank_in_group. + decomposed_rank_idx = decompose(rank_in_group, masked_shape) + rank.append( + inner_product(decomposed_rank_idx, masked_stride) + + inner_product(decomposed_group_idx, unmasked_stride) + ) + ranks.append(rank) + return ranks + + +class RankGenerator(object): + def __init__( + self, + tp: int, + sp: int, + cfg: int, + order: str, + rank_offset: int = 0, + ) -> None: + self.tp = tp + self.sp = sp + self.cfg = cfg + self.rank_offset = rank_offset + self.world_size = tp * sp * cfg + + self.name_to_size = { + "sp": self.sp, + "cfg": self.cfg, + "tp": self.tp, + } + order = order.lower() + + for name in self.name_to_size.keys(): + if name not in order and self.name_to_size[name] != 1: + raise RuntimeError( + f"The size of ({name}) is ({self.name_to_size[name]}), but you haven't specified the order ({self.order})." + ) + elif name not in order: + order = order + "-" + name + + self.order = order + self.ordered_size = [] + + for token in order.split("-"): + self.ordered_size.append(self.name_to_size[token]) + + def get_mask(self, order: str, token: str): + ordered_token = order.split("-") + token = token.split("-") + mask = [False] * len(ordered_token) + for t in token: + mask[ordered_token.index(t)] = True + return mask + + def get_ranks(self, token): + """Get rank group by input token. + + Arguments: + token (str): + Specify the ranks type that want to get. If we want + to obtain multiple parallel types, we can use a hyphen + '-' to separate them. For example, if we want to obtain + the TP_DP group, the token should be 'tp-dp'. + + independent_ep (bool: True): + This flag controls whether we treat EP and DP independently. + EP shares ranks with DP, if we want to get ranks related to + EP, we should set the flag. For example, get_ranks('dp', True) + will get DP modulo EP group, and get_ranks('dp', False) will + get full DP group. 
+ """ + mask = self.get_mask(self.order, token) + ranks = generate_masked_orthogonal_rank_groups( + self.world_size, self.ordered_size, mask + ) + if self.rank_offset > 0: + for rank_group in ranks: + for i, _ in enumerate(rank_group): + rank_group[i] += self.rank_offset + return ranks diff --git a/diffsynth/npu_utils/distributed/vae_patch_parallel.py b/diffsynth/npu_utils/distributed/vae_patch_parallel.py new file mode 100644 index 0000000000000000000000000000000000000000..6f6664281f72fdb6f43068cc0715782af30733c8 --- /dev/null +++ b/diffsynth/npu_utils/distributed/vae_patch_parallel.py @@ -0,0 +1,737 @@ +import torch +import torch_npu +import torch.distributed as dist +import torch.nn as nn +import torch.nn.functional as F +from functools import reduce +import functools + +class Parallel_VAE_SP: + def __init__(self, h_split=1, w_split=1, all_pp_group_ranks=None, **kwargs): + """ + Initialize distributed parallel processing parameters + + Args: + h_split (int): Number of splits along height dimension + w_split (int): Number of splits along width dimension + world_size (int): Total number of processes (default: current world size) + """ + if all_pp_group_ranks is None: + all_pp_group_ranks = [list(range(0, dist.get_world_size()))] + all_pp_group_size = [ len(pp_group_ranks) for pp_group_ranks in all_pp_group_ranks] + for s in all_pp_group_size: + assert s == all_pp_group_size[0], ( f" every group size should be same") + + world_size = all_pp_group_size[0] # Get total process count [[1]][[6]] + + # Validate world_size matches grid dimensions + assert w_split * h_split == world_size, ( + f"world_size must be {w_split} * {h_split} = {w_split*h_split}, but got {world_size}" + ) + + self._creat_pp_group(all_pp_group_ranks) + # self.rank is the rank in current_pp_group + self.rank = dist.get_rank(self.current_pp_group) # Current process rank [[6]] + self.world_size = dist.get_world_size(self.current_pp_group) + self.w_split = w_split + self.h_split = h_split + + # Calculate grid coordinates + self.row_rank = self.rank // w_split # Row index (0 to w_split-1) [[6]] + self.col_rank = self.rank % w_split # Column index (0 to h_split-1) [[6]] + + # Create communication groups + self._create_group_by_row(h_split, w_split, all_pp_group_ranks) + self._create_group_by_col(h_split, w_split, all_pp_group_ranks) + self._row_col_to_global_rank() + + self.ori_conv3d = None + + # world a list of list + def _creat_pp_group(self, all_pp_group_ranks=None): + for pp_group_ranks in all_pp_group_ranks: + group = dist.new_group(ranks=pp_group_ranks) + if dist.get_rank() in pp_group_ranks: + self.current_pp_group = group + # current_pp_group_ranks is the global rank of the current_pp_group + # the reason of need it , is irend irecv need global rank + self.current_pp_group_ranks = pp_group_ranks + + + def _create_group_by_row(self, h_split, w_split, all_pp_group_ranks): + """Create process groups for row-wise communication""" + for pp_group_ranks in all_pp_group_ranks: + for r in range(h_split): + ranks_in_row = [] + for c in range(w_split): + global_rank = pp_group_ranks[r * w_split + c] + ranks_in_row.append(global_rank) + row_group = dist.new_group(ranks=ranks_in_row) + if r == self.row_rank and dist.get_rank() in pp_group_ranks: + self.row_group = row_group + + def _create_group_by_col(self, h_split, w_split, all_pp_group_ranks): + """Create process groups for column-wise communication""" + for pp_group_ranks in all_pp_group_ranks: + for c in range(self.w_split): + ranks_in_col = [] + for r in range(self.h_split): + 
global_rank = pp_group_ranks[r * self.w_split + c] + ranks_in_col.append(global_rank) + col_group = dist.new_group(ranks=ranks_in_col) + if c == self.col_rank and dist.get_rank() in pp_group_ranks: + self.col_group = col_group + + + def _row_col_to_global_rank(self): + # Create rank mappings for communication + self.row_to_global_rank = { + r: self.current_pp_group_ranks[ + r * self.w_split + self.col_rank + ] + for r in range(self.h_split) + } + self.col_to_global_rank = { + c: self.current_pp_group_ranks[ + self.row_rank * self.w_split + c + ] + for c in range(self.w_split) + } + + def __call__(self, x): + """Split input tensor across last two dimensions""" + x = x.chunk(self.w_split, dim=-1)[self.col_rank] + x = x.chunk(self.h_split, dim=-2)[self.row_rank] + return x + + def patch(self, x, return_lst = False): + """ + Partition input tensor into grid blocks and record partition shapes + + Args: + x (torch.Tensor): Input tensor with shape [b, c, t, h, w] + + Returns: + torch.Tensor: Local partition tensor for current process + """ + # Get input dimensions + height, width = x.shape[-2:] + + # Calculate base partition dimensions + base_patch_height = height // self.h_split + base_patch_width = width // self.w_split + remainder_height = height % self.h_split + remainder_width = width % self.w_split + + # Generate partitions + patches = [] + for r in range(self.h_split): + for c in range(self.w_split): + # Calculate current partition dimensions + patch_height = base_patch_height + (1 if r < remainder_height else 0) + patch_width = base_patch_width + (1 if c < remainder_width else 0) + + # Calculate partition boundaries + start_h = r * base_patch_height + min(r, remainder_height) + end_h = start_h + patch_height + start_w = c * base_patch_width + min(c, remainder_width) + end_w = start_w + patch_width + + # Extract partition + patch = x[..., start_h:end_h, start_w:end_w] + patches.append(patch.contiguous()) + + # Get local partition + local_patch = patches[self.rank] + + return patches if return_lst else local_patch + + def dispatch(self, local_patch): + """ + Reconstruct full tensor through two-stage all-gather + + Args: + local_patch (torch.Tensor): Local partition tensor + + Returns: + torch.Tensor: Reconstructed full tensor + """ + # First all-gather to collect partition shapes + local_shape = torch.tensor(local_patch.shape[-2:], + device=local_patch.device, dtype=torch.int32) + shape_list = [torch.empty(2, dtype=torch.int32, + device=local_patch.device) for _ in range(self.world_size)] + dist.all_gather(shape_list, local_shape, group=self.current_pp_group) + + all_shapes = [tuple(shape.tolist()) for shape in shape_list] + + # Calculate original dimensions + total_h = 0 + total_w = 0 + row_heights = {} # Track row heights + col_widths = {} # Track column widths + + for rank in range(self.world_size): + r_rank = rank // self.w_split + c_rank = rank % self.w_split + h_part, w_part = all_shapes[rank] + + # Record first occurrence of row height + if r_rank not in row_heights: + row_heights[r_rank] = h_part + # Record first occurrence of column width + if c_rank not in col_widths: + col_widths[c_rank] = w_part + + total_h = sum(row_heights.values()) + total_w = sum(col_widths.values()) + # TODO dispatch should be release to process the [B C W H] + # Prepare buffers for data gathering + batch_size, channels, time_steps = local_patch.shape[:3] + + gathered_data = [ + torch.empty( + (batch_size * channels * time_steps * h_part * w_part,), + device=local_patch.device, + dtype=local_patch.dtype + ) 
for h_part, w_part in all_shapes + ] + # 执行 all_gather,确保所有进程发送相同长度的一维数据(需保证 local_patch 展平后长度与 element_counts 一致) + dist.all_gather(gathered_data, local_patch.view(-1).clone(), group=self.current_pp_group) + + # 将一维数据重新调整为目标形状 + for i, (h_part, w_part) in enumerate(all_shapes): + gathered_data[i] = gathered_data[i].view(batch_size, channels, time_steps, h_part, w_part) + + # Reconstruct full tensor + full_tensor = torch.empty( + (batch_size, channels, time_steps, total_h, total_w), + device=local_patch.device, + dtype=local_patch.dtype + ) + + current_row = 0 + for r in range(self.h_split): + current_col = 0 + row_height = row_heights[r] + for c in range(self.w_split): + rank = r * self.w_split + c + h_part, w_part = all_shapes[rank] + + # Place partition in correct position + full_tensor[:, :, :, current_row:current_row+h_part, + current_col:current_col+w_part] = gathered_data[rank] + current_col += col_widths[c] + current_row += row_height + + return full_tensor + + def exchange_columns(self, local_patch, pad=None): + """ + Perform column-wise data exchange with adjacent processes + + Args: + local_patch (torch.Tensor): Local partition tensor + pad (bool): Whether to add zero-padding for edge processes + + Returns: + torch.Tensor: Tensor with exchanged column data + """ + send_ops = [] + recv_ops = [] + left_recv = None + right_recv = None + + if self.w_split > 1: + # Send/receive left column + if self.col_rank > 0: + prev_rank = self.col_to_global_rank[self.col_rank - 1] + left_col = local_patch[..., :, :1].contiguous() + left_recv = torch.empty_like(left_col) + send_ops.append(dist.P2POp(dist.isend, left_col, prev_rank, group=self.row_group)) + recv_ops.append(dist.P2POp(dist.irecv, left_recv, prev_rank, group=self.row_group)) + + # Send/receive right column + if self.col_rank < self.w_split - 1: + next_rank = self.col_to_global_rank[self.col_rank + 1] + right_col = local_patch[..., :, -1:].contiguous() + right_recv = torch.empty_like(right_col) + send_ops.append(dist.P2POp(dist.isend, right_col, next_rank, group=self.row_group)) + recv_ops.append(dist.P2POp(dist.irecv, right_recv, next_rank, group=self.row_group)) + + # Execute communication + reqs = dist.batch_isend_irecv(send_ops + recv_ops) + for req in reqs: + req.wait() + + # Handle padding for edge cases + if pad: + left_pad = torch.zeros_like(local_patch[..., :, :1]) if self.col_rank == 0 else left_recv + right_pad = torch.zeros_like(local_patch[..., :, -1:]) if self.col_rank == self.w_split - 1 else right_recv + return torch.cat([left_pad, local_patch, right_pad], dim=-1).contiguous() + else: + if self.w_split > 1: + if self.col_rank == 0: + return torch.cat([local_patch, right_recv], dim=-1).contiguous() + elif self.col_rank == self.w_split - 1: + return torch.cat([left_recv, local_patch], dim=-1).contiguous() + else: + return torch.cat([left_recv, local_patch, right_recv], dim=-1).contiguous() + else: + return local_patch + + def exchange_rows(self, local_patch, pad=None): + """ + Perform row-wise data exchange with adjacent processes + + Args: + local_patch (torch.Tensor): Local partition tensor + pad (bool): Whether to add zero-padding for edge processes + + Returns: + torch.Tensor: Tensor with exchanged row data + """ + send_ops = [] + recv_ops = [] + top_recv = None + bottom_recv = None + + if self.h_split > 1: + # Send/receive top row + if self.row_rank > 0: + prev_rank = self.row_to_global_rank[self.row_rank - 1] + top_row = local_patch[..., :1, :].contiguous() + top_recv = torch.empty_like(top_row) + 
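+                # Pair an isend of this rank's top border row with an irecv of the
+                # upper neighbour's bottom border row; together with the symmetric
+                # exchange below, this builds the one-row halo that wraps_f_conv2d /
+                # wraps_f_conv3d rely on when padding along H is 1.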
send_ops.append(dist.P2POp(dist.isend, top_row, prev_rank, group=self.col_group)) + recv_ops.append(dist.P2POp(dist.irecv, top_recv, prev_rank, group=self.col_group)) + + # Send/receive bottom row + if self.row_rank < self.h_split - 1: + next_rank = self.row_to_global_rank[self.row_rank + 1] + bottom_row = local_patch[..., -1:, :].contiguous() + bottom_recv = torch.empty_like(bottom_row) + send_ops.append(dist.P2POp(dist.isend, bottom_row, next_rank, group=self.col_group)) + recv_ops.append(dist.P2POp(dist.irecv, bottom_recv, next_rank, group=self.col_group)) + + # Execute communication + reqs = dist.batch_isend_irecv(send_ops + recv_ops) + for req in reqs: + req.wait() + + # Handle padding for edge cases + if pad: + top_pad = torch.zeros_like(local_patch[..., :1, :]) if self.row_rank == 0 else top_recv + bottom_pad = torch.zeros_like(local_patch[..., -1:, :]) if self.row_rank == self.h_split - 1 else bottom_recv + return torch.cat([top_pad, local_patch, bottom_pad], dim=-2).contiguous() + else: + if self.h_split > 1: + if self.row_rank == 0: + return torch.cat([local_patch, bottom_recv], dim=-2).contiguous() + elif self.row_rank == self.h_split - 1: + return torch.cat([top_recv, local_patch], dim=-2).contiguous() + else: + return torch.cat([top_recv, local_patch, bottom_recv], dim=-2).contiguous() + else: + return local_patch + + def wraps_f_conv3d(self, f_conv3d=F.conv3d): + """ + Decorator to handle distributed 3D convolution with padding + + Args: + f_conv3d: Original convolution function + + Returns: + Wrapped convolution function with distributed padding handling + """ + self.ori_conv3d = f_conv3d + + def wrapped_conv3d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1): + # Process padding parameters + if isinstance(padding, int): + padding = (padding, padding, padding) + else: + padding = tuple(padding) + if len(padding) != 3: + raise ValueError("padding must be an int or a 3-element tuple") + + # Validate parameters + if padding[-1] not in {0, 1} or padding[-2] not in {0, 1}: + raise NotImplementedError("Only support padding[1]/padding[2] as 0 or 1") + if not all(s == 1 for s in (stride[-2:] if isinstance(stride, tuple) else (stride,))): + raise NotImplementedError("Only support stride=1 for dim H, W") + if not all(d == 1 for d in (dilation if isinstance(dilation, tuple) else (dilation,))): + raise NotImplementedError("Only support dilation=1") + + # Validate kernel size and padding relationship [[3]][[6]] + kernel_size = weight.shape[2:5] # Get kernel dimensions (depth, height, width) + if padding[1] * 2 + 1 != kernel_size[1] or padding[2] * 2 + 1 != kernel_size[2]: + raise ValueError( + f"3D Convolution requires: " + f"padding[1]*2+1 == kernel_size[1] and padding[2]*2+1 == kernel_size[2]. 
" + f"Got padding={padding}, kernel_size={kernel_size}" + ) + + # Handle row and column exchanges for padding + if padding[-2] == 1: + input = self.exchange_rows(input, pad=True) + if padding[-1] == 1: + input = self.exchange_columns(input, pad=True) + + # Call original convolution with adjusted padding + return self.ori_conv3d(input, weight, bias, stride=stride, padding=(padding[0],0,0), + dilation=1, groups=groups) + return wrapped_conv3d + + def wraps_f_conv2d(self, f_conv2d=F.conv2d): + """ + Decorator to handle distributed 2D convolution with padding + + Args: + f_conv2d: Original 2D convolution function + + Returns: + Wrapped 2D convolution function with distributed padding handling + """ + self.ori_conv2d = f_conv2d + + def wrapped_conv2d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1): + + # Handle stride parameter + if not isinstance(stride, tuple): + stride = (stride, stride) # Convert to tuple if not already + + if not all(s == 1 for s in stride): + # Dispatch input if any stride value is not 1 + input = self.dispatch(input.unsqueeze(2)).squeeze(2) + + # Dynamically calculate the split range + total_out_channels = weight.size(0) + base = total_out_channels // self.world_size + remainder = total_out_channels % self.world_size + + # Record the number of channels assigned to each device + channels_per_rank = [ + base + (1 if r < remainder else 0) for r in range(self.world_size) + ] + + # Current process channel range + start = sum(channels_per_rank[:self.rank]) + end = start + channels_per_rank[self.rank] + + weight_chunk = weight.narrow(0, start, end - start) + bias_chunk = bias.narrow(0, start, end - start) if bias is not None else None + + # Call original convolution with adjusted parameters + output = self.ori_conv2d( + input, weight_chunk, bias_chunk, stride, padding, dilation, groups) + + # On r-th NPU output [B, C/N_r, H, W] -> list of [B, C/N_r, H/h_split _i , W/w_split _i] for i = 0 ~ world size-1 + patches = self.patch(output, return_lst=True) + + # Construct the list of receiving shapes + # On i-th NPU [B, C/N_r, H/h_split _i , W/w_split _i] , for r = 0 ~ world size-1 + h_part, w_part = patches[self.rank].shape[-2:] + recv_shapes = [ + (output.shape[0], channels_per_rank[r], h_part, w_part) + for r in range(self.world_size) + ] + # Prepare buffers for all-to-all communication + gathered_outputs = [ + torch.empty(recv_shapes[r], dtype=output.dtype, device=output.device) + for r in range(self.world_size) + ] + + # Perform all-to-all communication to exchange data across processes + dist.all_to_all(gathered_outputs, patches, group=self.current_pp_group) + + # Concatenate gathered outputs along the channel dimension + full_output = torch.cat(gathered_outputs, dim=1) + + return full_output + + else: + + # Process padding parameters + if isinstance(padding, int): + padding = (padding, padding) + else: + padding = tuple(padding) + if len(padding) != 2: + raise ValueError("padding must be an int or a 2-element tuple") + + # Validate parameters + if padding[-1] not in {0, 1} or padding[-2] not in {0, 1}: + raise NotImplementedError("Only support padding values as 0 or 1") + if not (all(s == 1 for s in (stride if isinstance(stride, tuple) else (stride,))) and + all(d == 1 for d in (dilation if isinstance(dilation, tuple) else (dilation,)))): + raise NotImplementedError("Only support stride=1 and dilation=1") + + # Validate kernel size and padding relationship [[8]] + kernel_size = weight.shape[2:4] # Get kernel dimensions (height, width) + if padding[0] * 2 + 
1 != kernel_size[0] or padding[1] * 2 + 1 != kernel_size[1]: + raise ValueError( + f"2D Convolution requires: " + f"padding[0]*2+1 == kernel_size[0] and padding[1]*2+1 == kernel_size[1]. " + f"Got padding={padding}, kernel_size={kernel_size}" + ) + + # Handle row and column exchanges for padding + if padding[-2] == 1: + input = self.exchange_rows(input, pad=True) + if padding[-1] == 1: + input = self.exchange_columns(input, pad=True) + + # Call original convolution with adjusted padding + return self.ori_conv2d( + input, weight, bias, + stride=1, + padding=0, + dilation=1, + groups=groups + ) + return wrapped_conv2d + + def wraps_f_interpolate(self, f_interpolate=F.interpolate): + """ + Decorator to handle distributed interpolation operations + + Args: + f_interpolate: Original interpolation function + + Returns: + Wrapped interpolation function with distributed handling + """ + self.ori_interpolate = f_interpolate + + def wrapped_interpolate(input, size=None, scale_factor=None, mode='nearest', + align_corners=None, recompute_scale_factor=None, antialias=False): + # Validate inputs + if not isinstance(input, torch.Tensor): + raise TypeError("Input must be a PyTorch Tensor.") + if scale_factor is None: + raise ValueError("scale_factor must be provided") + + spatial_dims = input.dim() - 2 + if isinstance(scale_factor, int): + scale_factor = (scale_factor,) * spatial_dims + if not isinstance(scale_factor, tuple) or len(scale_factor) != spatial_dims: + raise ValueError(f"scale_factor must be an int or a tuple of length {spatial_dims}") + if any(sf > 2 for sf in scale_factor): + raise ValueError("Scale factors must not exceed 2") + + # Handle supported modes without data exchange + if mode in {"nearest", 'area', 'nearest-exact'}: # + return self.ori_interpolate( + input=input, + size=None, + scale_factor=scale_factor, + mode=mode, + align_corners=align_corners, + recompute_scale_factor=None, + antialias=False + ) + else: + # Handle modes requiring data exchange + use_exchange_rows = scale_factor[-2] == 2 + use_exchange_columns = scale_factor[-1] == 2 + + # Perform data exchange + if use_exchange_columns: + input = self.exchange_columns(input, pad=False) + if use_exchange_rows: + input = self.exchange_rows(input, pad=False) + + # Perform interpolation + output = self.ori_interpolate( + input=input, + size=None, + scale_factor=scale_factor, + mode=mode, + align_corners=align_corners, + recompute_scale_factor=None, + antialias=False + ) + + # Slice excess data + if use_exchange_columns and self.w_split > 1: + if self.col_rank == 0: + output = output[..., :-2] + elif self.col_rank < self.w_split - 1: + output = output[..., 2:-2] + else: + output = output[..., 2:] + + if use_exchange_rows: + if self.row_rank == 0: + output = output[..., :-2, :] + elif self.row_rank < self.h_split - 1: + output = output[..., 2:-2, :] + else: + output = output[..., 2:, :] + return output + return wrapped_interpolate + + def wraps_fa(self, fa, layout="BNSD"): + """ + Decorator for attention functions with distributed key/value handling + + Args: + fa: Original attention function + layout (str): Tensor layout ('BNSD' or 'BSND') + + Returns: + Wrapped attention function with distributed key/value handling + """ + self.ori_fa = fa + self.layout = layout + + def wrapped_fa(q, k, v, *args, **kwargs): + # Validate layout + if self.layout not in {"BNSD", "BSND"}: + raise ValueError("Unsupported layout. 
Only 'BNSD' and 'BSND' are supported.") + + # Gather key shapes across processes + local_shape = torch.tensor(k.shape, device=k.device) + all_shapes = [torch.empty_like(local_shape) for _ in range(self.world_size)] + dist.all_gather(all_shapes, local_shape, group=self.current_pp_group) + all_shapes = [tuple(shape.tolist()) for shape in all_shapes] + + # Prepare buffers for full keys/values + gathered_k = [torch.empty(shape, dtype=k.dtype, device=k.device) for shape in all_shapes] + gathered_v = [torch.empty_like(k_tensor) for k_tensor in gathered_k] + + # Gather full keys and values + dist.all_gather(gathered_k, k.contiguous(), group=self.current_pp_group) + dist.all_gather(gathered_v, v.contiguous(), group=self.current_pp_group) + + # Concatenate along sequence dimension + if layout == "BNSD": + full_k = torch.cat(gathered_k, dim=2) + full_v = torch.cat(gathered_v, dim=2) + else: + full_k = torch.cat(gathered_k, dim=1) + full_v = torch.cat(gathered_v, dim=1) + + # Call original attention function + return self.ori_fa(q, full_k, full_v, *args, **kwargs) + return wrapped_fa + + def wraps_decoder_fw(self, decoder_fw): + def wrapped_decoder_fw(input, *args,**kwargs): + input = self.patch(input) + output = decoder_fw(input, *args,**kwargs) + return self.dispatch(output) + return wrapped_decoder_fw + + def wraps_f_pad(self, f_pad=F.pad): + self.ori_pad = f_pad + def wrapped_pad(input, pad, mode='constant', value=None): + len_pad = len(pad) + if len_pad % 2 != 0: + raise ValueError("Padding length must be even-valued") + adapted_pad = list(pad) + if len_pad >1: + # Handle horizontal direction (left/right) + if self.w_split == 1: + # Apply full left/right padding when single slice + adapted_pad[0] = pad[0] + adapted_pad[1] = pad[1] + else: + # Apply pad[0], pad[1] to the left and right boundary + if self.col_rank == 0: + adapted_pad[0] = pad[0] + adapted_pad[1] = 0 + elif self.col_rank == self.w_split - 1: + adapted_pad[0] = 0 + adapted_pad[1] = pad[1] + else: + adapted_pad[0] = 0 + adapted_pad[1] = 0 + if len_pad > 3: + # Handle vertical direction (top/bottom) + if self.h_split == 1: + # Apply full top/bottom padding when single slice + adapted_pad[2] = pad[2] + adapted_pad[3] = pad[3] + else: + # Apply pad[2], pad[3] to the top and bottom boundary + if self.row_rank == 0: + adapted_pad[2] = pad[2] + adapted_pad[3] = 0 + elif self.row_rank == self.h_split - 1: + adapted_pad[2] = 0 + adapted_pad[3] = pad[3] + else: + adapted_pad[2] = 0 + adapted_pad[3] = 0 + + return self.ori_pad(input, tuple(adapted_pad), mode=mode, value=value) + return wrapped_pad + +VAE_PATCH_PARALLEL = None +FA_LAYOUT = None + +def set_vae_patch_parallel(vae,h_split=1, w_split=1, fa_layout="BNSD",decoder_decode="decoder.forward", + all_pp_group_ranks=None, **kwargs): + global VAE_PATCH_PARALLEL + global FA_LAYOUT + if VAE_PATCH_PARALLEL is None: + VAE_PATCH_PARALLEL = Parallel_VAE_SP(h_split, w_split, all_pp_group_ranks) + FA_LAYOUT = fa_layout + + # wraps_decoder_fw + decoder_decode_lst = decoder_decode.split(".") + # the function + ori_decoder_decode_func = reduce(getattr, decoder_decode_lst, vae) + # the name of the function + decoder_decode_func = decoder_decode_lst.pop() + ori_vae_decoder = reduce(getattr, decoder_decode_lst, vae) + + new_decoder_decode = VAE_PATCH_PARALLEL.wraps_decoder_fw(ori_decoder_decode_func) + setattr(ori_vae_decoder, decoder_decode_func, new_decoder_decode) + return ori_decoder_decode_func + +def get_vae_patch_parallel(): + return VAE_PATCH_PARALLEL + +class VAE_patch_parallel: + def 
__init__(self): + global VAE_PATCH_PARALLEL + self.vae_pp_cls = VAE_PATCH_PARALLEL + def __enter__(self): + if self.vae_pp_cls is not None: + self._sub_F_func() + self._sub_FA() + + def __exit__(self,t,v,trace): + if self.vae_pp_cls is not None: + self._revert_F_func() + self._revert_FA() + + def _sub_F_func(self): + F.conv3d = self.vae_pp_cls.wraps_f_conv3d(F.conv3d) + F.conv2d = self.vae_pp_cls.wraps_f_conv2d(F.conv2d) + F.interpolate = self.vae_pp_cls.wraps_f_interpolate(F.interpolate) + F.pad = self.vae_pp_cls.wraps_f_pad(F.pad) + + def _sub_FA(self): + global FA_LAYOUT + F.scaled_dot_product_attention = self.vae_pp_cls.wraps_fa( + F.scaled_dot_product_attention, layout=FA_LAYOUT) + + def _revert_F_func(self): + """Restore original PyTorch functions after context exit""" + if self.vae_pp_cls.ori_conv3d is not None: + F.conv3d = self.vae_pp_cls.ori_conv3d + if self.vae_pp_cls.ori_conv2d is not None: + F.conv2d = self.vae_pp_cls.ori_conv2d + if self.vae_pp_cls.ori_interpolate is not None: + F.interpolate = self.vae_pp_cls.ori_interpolate + if self.vae_pp_cls.ori_pad is not None: + F.pad = self.vae_pp_cls.ori_pad + + def _revert_FA(self): + """Restore original attention function after context exit""" + if self.vae_pp_cls.ori_fa is not None: + F.scaled_dot_product_attention = self.vae_pp_cls.ori_fa \ No newline at end of file diff --git a/diffsynth/npu_utils/modules/attn_layer.py b/diffsynth/npu_utils/modules/attn_layer.py new file mode 100644 index 0000000000000000000000000000000000000000..fb48f159d3899c1f7ed6491d0b505e0125257068 --- /dev/null +++ b/diffsynth/npu_utils/modules/attn_layer.py @@ -0,0 +1,169 @@ +import logging +import torch +from torch import Tensor +import torch_npu +import torch.distributed as dist +import math +import os +from yunchang import LongContextAttention +try: + from yunchang.kernels import AttnType +except ImportError: + raise ImportError("Please install yunchang 0.6.0 or later") +from typing import Any + +from mindiesd.layers.flash_attn.attention_forward import attention_forward + +from ..distributed.parallel_mgr import get_sp_group +from ..distributed.comm import all_to_all_4D + +logger = logging.getLogger(__name__) +MAX_TOKEN = 2147483647 + +class xFuserLongContextAttention(LongContextAttention): + ring_impl_type_supported_kv_cache = ["basic"] + + def __init__( + self, + args: Any = None, + scatter_idx: int = 2, + gather_idx: int = 1, + ring_impl_type: str = "basic", + use_pack_qkv: bool = False, + use_kv_cache: bool = False, + attn_type: AttnType = AttnType.FA, + ) -> None: + """ + Arguments: + scatter_idx: int = 2, the scatter dimension index for Ulysses All2All + gather_idx: int = 1, the gather dimension index for Ulysses All2All + ring_impl_type: str = "basic", the ring implementation type, currently only support "basic" + use_pack_qkv: bool = False, whether to use pack qkv in the input + use_kv_cache: bool = False, whether to use kv cache in the attention layer, which is applied in PipeFusion. + """ + super().__init__( + scatter_idx=scatter_idx, + gather_idx=gather_idx, + ring_impl_type=ring_impl_type, + use_pack_qkv=use_pack_qkv, + attn_type = attn_type, + ) + self.use_kv_cache = use_kv_cache + if ( + use_kv_cache + and ring_impl_type not in self.ring_impl_type_supported_kv_cache + ): + raise RuntimeError( + f"ring_impl_type: {ring_impl_type} do not support SP kv cache." 
+ ) + self.world_size = dist.get_world_size() + self.args = args + self.video_size = ['480*832', '832*480', '480*720', '720*480'] + + self.algo = int(os.getenv('ALGO', 0)) + + """ + if self.args.size in self.video_size: + self.use_all_head = True + else: + self.use_all_head = False + """ + + self.ulysses_pg = get_sp_group().ulysses_group + self.ring_pg = get_sp_group().ring_group + + def forward( + self, + attn, + query: Tensor, + key: Tensor, + value: Tensor, + *, + joint_tensor_query=None, + joint_tensor_key=None, + joint_tensor_value=None, + dropout_p=0.0, + softmax_scale=None, + causal=False, + window_size=(-1, -1), + alibi_slopes=None, + deterministic=False, + return_attn_probs=False, + joint_strategy="none", + scale=None + ) -> Tensor: + """forward + + Arguments: + attn (Attention): the attention module + query (Tensor): query input to the layer + key (Tensor): key input to the layer + value (Tensor): value input to the layer + args: other args, + joint_tensor_query: Tensor = None, a replicated tensor among processes appended to the front or rear of query, depends the joint_strategy + joint_tensor_key: Tensor = None, a replicated tensor among processes appended to the front or rear of key, depends the joint_strategy + joint_tensor_value: Tensor = None, a replicated tensor among processes appended to the front or rear of value, depends the joint_strategy, + *args: the args same as flash_attn_interface + joint_strategy: str = "none", the joint strategy for joint attention, currently only support "front" and "rear" + + Returns: + * output (Tensor): context output + """ + + query_layer = all_to_all_4D(input_=query, scatter_idx=2, gather_idx=1, group=self.ulysses_pg) + key_layer = all_to_all_4D(input_=key, scatter_idx=2, gather_idx=1, group=self.ulysses_pg) + value_layer = all_to_all_4D(input_=value, scatter_idx=2, gather_idx=1, group=self.ulysses_pg) + + if get_sp_group().ring_world_size > 1: + ring_size = get_sp_group().ring_world_size + b, s, n, d = key_layer.shape + k_full = torch.empty([ring_size, b, s, n, d], dtype=query_layer.dtype, device=query_layer.device) + dist.all_gather_into_tensor(k_full, key_layer, group=self.ring_pg) + key_layer = k_full.permute(1, 0, 2, 3, 4).reshape(b, -1, n, d) + + v_full = torch.empty([ring_size, b, s, n, d], dtype=query_layer.dtype, device=query_layer.device) + dist.all_gather_into_tensor(v_full, value_layer, group=self.ring_pg) + value_layer = v_full.permute(1, 0, 2, 3, 4).reshape(b, -1, n, d) + + + # if self.use_all_head: + try: + if self.algo == 0: + out = attention_forward(query_layer, key_layer, value_layer, + opt_mode="manual", op_type="fused_attn_score", layout="BNSD") + elif self.algo == 1: + out = attention_forward(query_layer, key_layer, value_layer, + opt_mode="manual", op_type="ascend_laser_attention", layout="BNSD") + else: + raise ValueError(f"select flash attention algorithm only support 0, 1, but got {self.algo}") + # else: + except: + query_layer_list = query_layer.split(1, dim=2) + key_layer_list = key_layer.split(1, dim=2) + value_layer_list = value_layer.split(1, dim=2) + output = [] + for_loop = query_layer.shape[2] + for i in range(for_loop): + if self.algo == 0: + out = attention_forward(query_layer_list[i], key_layer_list[i], value_layer_list[i], + opt_mode="manual", op_type="fused_attn_score", layout="BNSD") + elif self.algo == 1: + out = attention_forward(query_layer_list[i], key_layer_list[i], value_layer_list[i], + opt_mode="manual", op_type="ascend_laser_attention", layout="BNSD") + else: + raise ValueError(f"select flash 
attention algorithm only support 0, 1, but got f{self.algo}") + + output.append(out) + out = torch.cat(output, dim=2) + + if type(out) == tuple: + context_layer, _, _ = out + else: + context_layer = out + + # (bs, seq_len, head_cnt/N, head_size) -> (bs, seq_len/N, head_cnt, head_size) + # scatter 1, gather 2 + output = all_to_all_4D(input_=context_layer, scatter_idx=1, gather_idx=2, group=self.ulysses_pg) + + return output + diff --git a/diffsynth/pipelines/base.py b/diffsynth/pipelines/base.py index 2a4f01cff55dc0fcca02dc5234227bd65efc7434..81e6723be82c8650ea07306a748ba8c19bfa7186 100644 --- a/diffsynth/pipelines/base.py +++ b/diffsynth/pipelines/base.py @@ -118,7 +118,8 @@ class BasePipeline(torch.nn.Module): else: model.to(self.device) # fresh the cuda cache - torch.cuda.empty_cache() + if not torch.npu.is_available(): + torch.cuda.empty_cache() def generate_noise(self, shape, seed=None, device="cpu", dtype=torch.float16): diff --git a/diffsynth/pipelines/wan_video.py b/diffsynth/pipelines/wan_video.py index e70e0cc246e2d41ffd19056fd42060144ab6230c..6536dee2ed36703afa6ef8d85a4590e1c53a96f1 100644 --- a/diffsynth/pipelines/wan_video.py +++ b/diffsynth/pipelines/wan_video.py @@ -184,7 +184,10 @@ class WanVideoPipeline(BasePipeline): pipe = WanVideoPipeline(device=device, torch_dtype=torch_dtype) pipe.fetch_models(model_manager) if use_usp: - from xfuser.core.distributed import get_sequence_parallel_world_size + if torch.npu.is_available(): + from diffsynth.npu_utils.distributed.parallel_mgr import get_sequence_parallel_world_size + else: + from xfuser.core.distributed import get_sequence_parallel_world_size from ..distributed.xdit_context_parallel import usp_attn_forward, usp_dit_forward for block in pipe.dit.blocks: @@ -553,9 +556,15 @@ def model_fn_wan_video( ): if use_unified_sequence_parallel: import torch.distributed as dist - from xfuser.core.distributed import (get_sequence_parallel_rank, - get_sequence_parallel_world_size, - get_sp_group) + if torch.npu.is_available(): + from diffsynth.npu_utils.distributed.parallel_mgr import ( + get_sequence_parallel_rank, + get_sequence_parallel_world_size, + get_sp_group) + else: + from xfuser.core.distributed import (get_sequence_parallel_rank, + get_sequence_parallel_world_size, + get_sp_group) t = dit.time_embedding(sinusoidal_embedding_1d(dit.freq_dim, timestep)) t_mod = dit.time_projection(t).unflatten(1, (6, dit.dim)) diff --git a/diffsynth/pipelines/wan_video_new.py b/diffsynth/pipelines/wan_video_new.py index 2317422a1c6606ff37f36a04660baaae7794538f..015b5ed934c22c17e42bc68a00749709ad277d36 100644 --- a/diffsynth/pipelines/wan_video_new.py +++ b/diffsynth/pipelines/wan_video_new.py @@ -1,4 +1,9 @@ +import os +import time +import logging import torch, warnings, glob, os, types +import torch_npu +import torch.distributed as dist import numpy as np from PIL import Image from einops import repeat, reduce @@ -11,12 +16,13 @@ from PIL import Image from tqdm import tqdm from typing import Optional from typing_extensions import Literal +import torch.distributed as dist from ..utils import BasePipeline, ModelConfig, PipelineUnit, PipelineUnitRunner from ..models import ModelManager, load_state_dict from ..models.wan_video_dit import WanModel, RMSNorm, sinusoidal_embedding_1d from ..models.wan_video_text_encoder import WanTextEncoder, T5RelativeEmbedding, T5LayerNorm -from ..models.wan_video_vae import WanVideoVAE, RMS_norm, CausalConv3d, Upsample +from ..models.wan_video_vae import WanVideoVAE, RMS_norm, CausalConv3d, Upsample, 
enable_vae_patch_parallel from ..models.wan_video_image_encoder import WanImageEncoder from ..models.wan_video_vace import VaceWanModel from ..models.wan_video_motion_controller import WanMotionControllerModel @@ -25,11 +31,17 @@ from ..prompters import WanPrompter from ..vram_management import enable_vram_management, AutoWrappedModule, AutoWrappedLinear, WanAutoCastLayerNorm from ..lora import GeneralLoRALoader +from ..npu_utils.distributed.vae_patch_parallel import set_vae_patch_parallel, VAE_patch_parallel +try: + from mindiesd import rotary_position_embedding + NPU_ROPE_AVAILABLE = True +except: + NPU_ROPE_AVAILABLE = False class WanVideoPipeline(BasePipeline): - def __init__(self, device="cuda", torch_dtype=torch.bfloat16, tokenizer_path=None): + def __init__(self, device="cuda", torch_dtype=torch.bfloat16, tokenizer_path=None, **kwargs): super().__init__( device=device, torch_dtype=torch_dtype, height_division_factor=16, width_division_factor=16, time_division_factor=4, time_division_remainder=1 @@ -64,7 +76,19 @@ class WanVideoPipeline(BasePipeline): WanVideoUnit_CfgMerger(), ] self.model_fn = model_fn_wan_video - + + if torch.npu.is_available(): + self.ulysses_size = kwargs['ulysses_size'] if 'ulysses_size' in kwargs else 1 + self.ring_size = kwargs['ring_size'] if 'ring_size' in kwargs else 1 + self.cfg_size = kwargs['cfg_size'] if 'cfg_size' in kwargs else 1 + self.tp_size = kwargs['tp_size'] if 'tp_size' in kwargs else 1 + self.t5_fsdp = bool(kwargs['t5_fsdp']) if 't5_fsdp' in kwargs else False + self.dit_fsdp = bool(kwargs['dit_fsdp']) if 'dit_fsdp' in kwargs else False + self.vae_parallel = bool(kwargs['vae_parallel']) if 'vae_parallel' in kwargs else False + + assert self.tp_size == 1, f"Tensor parallel is not supported on NPU yet. The tp_size is {self.tp_size}." + assert self.cfg_size == 1, f"Classifier free guidance is not supported on NPU yet. The cfg_size is {self.cfg_size}." + assert self.dit_fsdp == False, f"dit_fsdp is not supported on NPU yet. The dit_fsdp is {self.dit_fsdp}." 
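+        # Illustrative sketch of how these kwargs are expected to arrive (a
+        # hypothetical 8-NPU torchrun launch; from_pretrained forwards its extra
+        # **kwargs to this constructor, and the four sizes must multiply to the
+        # world size when USP is enabled):
+        #
+        #   pipe = WanVideoPipeline.from_pretrained(
+        #       torch_dtype=torch.bfloat16, device="npu", use_usp=True,
+        #       ulysses_size=8, ring_size=1, cfg_size=1, tp_size=1,
+        #   )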
def load_lora(self, module, path, alpha=1): loader = GeneralLoRALoader(torch_dtype=self.torch_dtype, device=self.device) @@ -108,9 +132,9 @@ class WanVideoPipeline(BasePipeline): }, module_config = dict( offload_dtype=dtype, - offload_device="cpu", + offload_device=self.device if torch.npu.is_available() else "cpu", onload_dtype=dtype, - onload_device="cpu", + onload_device=self.device if torch.npu.is_available() else "cpu", computation_dtype=self.torch_dtype, computation_device=self.device, ), @@ -193,9 +217,9 @@ class WanVideoPipeline(BasePipeline): }, module_config = dict( offload_dtype=dtype, - offload_device="cpu", + offload_device=self.device if torch.npu.is_available() else "cpu", onload_dtype=dtype, - onload_device=self.device, + onload_device=self.device if torch.npu.is_available() else "cpu", computation_dtype=self.torch_dtype, computation_device=self.device, ), @@ -211,9 +235,9 @@ class WanVideoPipeline(BasePipeline): }, module_config = dict( offload_dtype=dtype, - offload_device="cpu", + offload_device=self.device if torch.npu.is_available() else "cpu", onload_dtype=dtype, - onload_device="cpu", + onload_device=self.device if torch.npu.is_available() else "cpu", computation_dtype=dtype, computation_device=self.device, ), @@ -257,22 +281,53 @@ class WanVideoPipeline(BasePipeline): def initialize_usp(self): - import torch.distributed as dist - from xfuser.core.distributed import initialize_model_parallel, init_distributed_environment - dist.init_process_group(backend="nccl", init_method="env://") - init_distributed_environment(rank=dist.get_rank(), world_size=dist.get_world_size()) - initialize_model_parallel( - sequence_parallel_degree=dist.get_world_size(), - ring_degree=1, - ulysses_degree=dist.get_world_size(), - ) - torch.cuda.set_device(dist.get_rank()) - + if torch.npu.is_available(): + from diffsynth.npu_utils.distributed.parallel_mgr import init_parallel_env, ParallelConfig + dist.init_process_group(backend="hccl", init_method="env://") + self.device = torch.device(f"npu:{os.getenv('RANK')}") + torch.cuda.set_device(dist.get_rank()) + + world_size = dist.get_world_size() + assert world_size <= 8, f"The world size should be less than or equal to 8." + if self.ulysses_size != world_size: + self.ulysses_size = world_size + + if self.ulysses_size == 1 and self.ring_size == 1 and self.cfg_size == 1 and self.tp_size == 1: + logging.warning("ulysses_size, ring_size, cfg_size, tp_size are not specified, using world_size as ulysses_size") + self.ulysses_size = dist.get_world_size() + assert self.cfg_size * self.ulysses_size * self.ring_size * self.tp_size == world_size, f"The number of cfg_size, ulysses_size, ring_size and tp_size should be equal to the world size." 
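+            # Illustrative: with the defaults above on a single 8-NPU node this
+            # resolves to ulysses_size=8, ring_size=1, cfg_size=1, tp_size=1,
+            # i.e. one pure Ulysses sequence-parallel group spanning all devices.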
+ + sp_degree = self.ulysses_size * self.ring_size + parallel_config = ParallelConfig( + sp_degree=sp_degree, + ulysses_degree=self.ulysses_size, + ring_degree=self.ring_size, + tp_degree=self.tp_size, + use_cfg_parallel=(self.cfg_size==2), + world_size=world_size, + ) + init_parallel_env(parallel_config) + + else: + from xfuser.core.distributed import initialize_model_parallel, init_distributed_environment + dist.init_process_group(backend="nccl", init_method="env://") + init_distributed_environment(rank=dist.get_rank(), world_size=dist.get_world_size()) + initialize_model_parallel( + sequence_parallel_degree=dist.get_world_size(), + ring_size=1, + ulysses_size=dist.get_world_size(), + ) + torch.cuda.set_device(dist.get_rank()) + def enable_usp(self): - from xfuser.core.distributed import get_sequence_parallel_world_size - from ..distributed.xdit_context_parallel import usp_attn_forward, usp_dit_forward + if torch.npu.is_available(): + from diffsynth.npu_utils.distributed.parallel_mgr import get_sequence_parallel_world_size + else: + from xfuser.core.distributed import get_sequence_parallel_world_size + from ..distributed.xdit_context_parallel import usp_attn_forward, usp_dit_forward, usp_dit_forward_vace + # USP DiT for block in self.dit.blocks: block.self_attn.forward = types.MethodType(usp_attn_forward, block.self_attn) self.dit.forward = types.MethodType(usp_dit_forward, self.dit) @@ -280,6 +335,13 @@ class WanVideoPipeline(BasePipeline): for block in self.dit2.blocks: block.self_attn.forward = types.MethodType(usp_attn_forward, block.self_attn) self.dit2.forward = types.MethodType(usp_dit_forward, self.dit2) + + # USP VACE + if self.vace is not None: + for block in self.vace.vace_blocks: + block.self_attn.forward = types.MethodType(usp_attn_forward, block.self_attn) + self.vace.forward = types.MethodType(usp_dit_forward_vace, self.vace) + self.sp_size = get_sequence_parallel_world_size() self.use_unified_sequence_parallel = True @@ -292,7 +354,10 @@ class WanVideoPipeline(BasePipeline): tokenizer_config: ModelConfig = ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="google/*"), redirect_common_files: bool = True, use_usp=False, + **kwargs ): + if torch.npu.is_available() and 'vae_parallel' in kwargs and kwargs.get("vae_parallel", True): + enable_vae_patch_parallel() # Redirect model path if redirect_common_files: redirect_dict = { @@ -308,8 +373,23 @@ class WanVideoPipeline(BasePipeline): model_config.model_id = redirect_dict[model_config.origin_file_pattern] # Initialize pipeline - pipe = WanVideoPipeline(device=device, torch_dtype=torch_dtype) - if use_usp: pipe.initialize_usp() + pipe = WanVideoPipeline(device=device, torch_dtype=torch_dtype, **kwargs) + if use_usp: + pipe.initialize_usp() + else: + assert not ( + pipe.t5_fsdp and pipe.dit_fsdp + ), f"t5_fsdp and dit_fsdp are not supported in non-distributed environments." + assert not ( + pipe.ulysses_size > 1 or pipe.ring_size > 1 or pipe.cfg_size > 1 or pipe.tp_size > 1 + ), f"context parallel is not supported in non-distributed environments." + assert not ( + pipe.vae_parallel + ), f"vae parallel are not supported in non-distributed environments." 
+ + if pipe.tp_size > 1 and pipe.dit_fsdp: + logging.warning("Tensor parallel is not supported in dit_fsdp mode, setting dit_fsdp to False") + pipe.dit_fsdp = False # Download and load models model_manager = ModelManager() @@ -320,15 +400,37 @@ class WanVideoPipeline(BasePipeline): device=model_config.offload_device or device, torch_dtype=model_config.offload_dtype or torch_dtype ) + print(f"[Device: {int(os.getenv('RANK'))}] Finish loading models. Memory usage in GB: {torch.npu.memory_allocated(device=device) / 1024 / 1024 / 1024}") # Load models pipe.text_encoder = model_manager.fetch_model("wan_video_text_encoder") + + if pipe.t5_fsdp or pipe.dit_fsdp: + from diffsynth.npu_utils.distributed.fsdp import shard_model + from functools import partial + shard_fn = partial(shard_model, device_id=torch.device(f"npu:{dist.get_rank()}")) + + if pipe.t5_fsdp: + pipe.text_encoder = shard_fn(pipe.text_encoder, sync_module_states=False) + dit = model_manager.fetch_model("wan_video_dit", index=2) if isinstance(dit, list): pipe.dit, pipe.dit2 = dit else: pipe.dit = dit + if pipe.dit_fsdp: + pipe.dit = shard_fn(pipe.dit, sync_module_states=False) + if pipe.dit2 is not None: + pipe.dit2 = shard_fn(pipe.dit2, sync_module_states=False) + pipe.vae = model_manager.fetch_model("wan_video_vae") + if pipe.vae_parallel: + all_pp_group_ranks = [] + for i in range(0, dist.get_world_size() // 8): + all_pp_group_ranks.append(list(range(8 * i, 8 * (i + 1)))) + set_vae_patch_parallel(pipe.vae.model, 4, 2, all_pp_group_ranks= all_pp_group_ranks, decoder_decode="decoder.forward") + set_vae_patch_parallel(pipe.vae.model, 4, 2, all_pp_group_ranks= all_pp_group_ranks, decoder_decode="encoder.forward") + pipe.image_encoder = model_manager.fetch_model("wan_video_image_encoder") pipe.motion_controller = model_manager.fetch_model("wan_video_motion_controller") pipe.vace = model_manager.fetch_model("wan_video_vace") @@ -345,6 +447,8 @@ class WanVideoPipeline(BasePipeline): # Unified Sequence Parallel if use_usp: pipe.enable_usp() + + print(f"Finish loading models. Memory usage in GB: {torch.npu.memory_allocated(device=device) / 1024 / 1024 / 1024}") return pipe @@ -403,6 +507,10 @@ class WanVideoPipeline(BasePipeline): # progress_bar progress_bar_cmd=tqdm, ): + if self.vae_parallel and tiled: + if int(os.getenv("RANK", 0)) == 0: + print("vae_parallel and tiled are not supported together. 
Setting tiled to False.") + tiled = False # Scheduler self.scheduler.set_timesteps(num_inference_steps, denoising_strength=denoising_strength, shift=sigma_shift) @@ -430,46 +538,117 @@ class WanVideoPipeline(BasePipeline): "tiled": tiled, "tile_size": tile_size, "tile_stride": tile_stride, "sliding_window_size": sliding_window_size, "sliding_window_stride": sliding_window_stride, } + encode_time = time.time() for unit in self.units: inputs_shared, inputs_posi, inputs_nega = self.unit_runner(unit, self, inputs_shared, inputs_posi, inputs_nega) - + encode_time = time.time() - encode_time + if int(os.getenv("RANK", 0)) == 0: + print(f"Encode time: {encode_time}") + # Denoise self.load_models_to_device(self.in_iteration_models) models = {name: getattr(self, name) for name in self.in_iteration_models} - for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)): - # Switch DiT if necessary - if timestep.item() < switch_DiT_boundary * self.scheduler.num_train_timesteps and self.dit2 is not None and not models["dit"] is self.dit2: - self.load_models_to_device(self.in_iteration_models_2) - models["dit"] = self.dit2 - - # Timestep - timestep = timestep.unsqueeze(0).to(dtype=self.torch_dtype, device=self.device) - - # Inference - noise_pred_posi = self.model_fn(**models, **inputs_shared, **inputs_posi, timestep=timestep) - if cfg_scale != 1.0: - if cfg_merge: - noise_pred_posi, noise_pred_nega = noise_pred_posi.chunk(2, dim=0) + + + dit_time = time.time() + if os.getenv("PROFILING_ENABLE", "0") == "1": + profiling_dir = os.getenv("PROFILING_DIR", "./prof") + experimental_config = torch_npu.profiler._ExperimentalConfig( + export_type=torch_npu.profiler.ExportType.Text, + aic_metrics=torch_npu.profiler.AiCMetrics.PipeUtilization, + profiler_level=torch_npu.profiler.ProfilerLevel.Level1, + l2_cache=False, + data_simplification=False + ) + with torch_npu.profiler.profile( + activities=[torch_npu.profiler.ProfilerActivity.CPU, torch_npu.profiler.ProfilerActivity.NPU], + with_stack=True, + record_shapes=True, + profile_memory=True, + schedule=torch_npu.profiler.schedule(wait=2, warmup=2, active=1, repeat=1, skip_first=10), + experimental_config=experimental_config, + on_trace_ready=torch_npu.profiler.tensorboard_trace_handler(profiling_dir) + ) as prof: + for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)): + # Switch DiT if necessary + if timestep.item() < switch_DiT_boundary * self.scheduler.num_train_timesteps and self.dit2 is not None and not models["dit"] is self.dit2: + self.load_models_to_device(self.in_iteration_models_2) + models["dit"] = self.dit2 + + # Timestep + timestep = timestep.unsqueeze(0).to(dtype=self.torch_dtype, device=self.device) + # timestep = torch.stack([timestep]) + + # Inference + noise_pred_posi = self.model_fn(**models, **inputs_shared, **inputs_posi, timestep=timestep) + if cfg_scale != 1.0: + if cfg_merge: + noise_pred_posi, noise_pred_nega = noise_pred_posi.chunk(2, dim=0) + else: + noise_pred_nega = self.model_fn(**models, **inputs_shared, **inputs_nega, timestep=timestep) + noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega) + else: + noise_pred = noise_pred_posi + + # Scheduler + inputs_shared["latents"] = self.scheduler.step(noise_pred, timestep, inputs_shared["latents"]) + if "first_frame_latents" in inputs_shared: + inputs_shared["latents"][:, :, 0:1] = inputs_shared["first_frame_latents"] + + prof.step() + else: + for progress_id, timestep in 
enumerate(progress_bar_cmd(self.scheduler.timesteps)): + # Switch DiT if necessary + if timestep.item() < switch_DiT_boundary * self.scheduler.num_train_timesteps and self.dit2 is not None and not models["dit"] is self.dit2: + self.load_models_to_device(self.in_iteration_models_2) + models["dit"] = self.dit2 + + # Timestep + timestep = timestep.unsqueeze(0).to(dtype=self.torch_dtype, device=self.device) + + # Inference + noise_pred_posi = self.model_fn(**models, **inputs_shared, **inputs_posi, timestep=timestep) + if cfg_scale != 1.0: + if cfg_merge: + noise_pred_posi, noise_pred_nega = noise_pred_posi.chunk(2, dim=0) + else: + noise_pred_nega = self.model_fn(**models, **inputs_shared, **inputs_nega, timestep=timestep) + noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega) else: - noise_pred_nega = self.model_fn(**models, **inputs_shared, **inputs_nega, timestep=timestep) - noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega) - else: - noise_pred = noise_pred_posi + noise_pred = noise_pred_posi + + # Scheduler + inputs_shared["latents"] = self.scheduler.step(noise_pred, self.scheduler.timesteps[progress_id], inputs_shared["latents"]) + if "first_frame_latents" in inputs_shared: + inputs_shared["latents"][:, :, 0:1] = inputs_shared["first_frame_latents"] + dit_time = time.time() - dit_time + if int(os.getenv("RANK", 0)) == 0: + print(f"Dit time: {dit_time}") - # Scheduler - inputs_shared["latents"] = self.scheduler.step(noise_pred, self.scheduler.timesteps[progress_id], inputs_shared["latents"]) - if "first_frame_latents" in inputs_shared: - inputs_shared["latents"][:, :, 0:1] = inputs_shared["first_frame_latents"] - # VACE (TODO: remove it) if vace_reference_image is not None: inputs_shared["latents"] = inputs_shared["latents"][:, :, 1:] # Decode + decode_time = time.time() self.load_models_to_device(['vae']) - video = self.vae.decode(inputs_shared["latents"], device=self.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride) + with VAE_patch_parallel(): + video = self.vae.decode(inputs_shared["latents"], device=self.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride) video = self.vae_output_to_video(video) self.load_models_to_device([]) + decode_time = time.time() - decode_time + if int(os.getenv("RANK", 0)) == 0: + print(f"Decode time: {decode_time}") + + del inputs_shared + del inputs_posi + del inputs_nega + if self.dit is not None: + del self.dit.freqs_list + self.dit.freqs_list = None + if self.dit2 is not None: + del self.dit2.freqs_list + self.dit2.freqs_list = None return video @@ -513,10 +692,12 @@ class WanVideoUnit_InputVideoEmbedder(PipelineUnit): return {"latents": noise} pipe.load_models_to_device(["vae"]) input_video = pipe.preprocess_video(input_video) - input_latents = pipe.vae.encode(input_video, device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(dtype=pipe.torch_dtype, device=pipe.device) + with VAE_patch_parallel(): + input_latents = pipe.vae.encode(input_video, device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(dtype=pipe.torch_dtype, device=pipe.device) if vace_reference_image is not None: vace_reference_image = pipe.preprocess_video([vace_reference_image]) - vace_reference_latents = pipe.vae.encode(vace_reference_image, device=pipe.device).to(dtype=pipe.torch_dtype, device=pipe.device) + with VAE_patch_parallel(): + vace_reference_latents = pipe.vae.encode(vace_reference_image, device=pipe.device).to(dtype=pipe.torch_dtype, 
device=pipe.device) input_latents = torch.concat([vace_reference_latents, input_latents], dim=2) if pipe.scheduler.training: return {"latents": noise, "input_latents": input_latents} @@ -573,7 +754,8 @@ class WanVideoUnit_ImageEmbedder(PipelineUnit): msk = msk.view(1, msk.shape[1] // 4, 4, height//8, width//8) msk = msk.transpose(1, 2)[0] - y = pipe.vae.encode([vae_input.to(dtype=pipe.torch_dtype, device=pipe.device)], device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)[0] + with VAE_patch_parallel(): + y = pipe.vae.encode([vae_input.to(dtype=pipe.torch_dtype, device=pipe.device)], device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)[0] y = y.to(dtype=pipe.torch_dtype, device=pipe.device) y = torch.concat([msk, y]) y = y.unsqueeze(0) @@ -630,7 +812,8 @@ class WanVideoUnit_ImageEmbedderVAE(PipelineUnit): msk = msk.view(1, msk.shape[1] // 4, 4, height//8, width//8) msk = msk.transpose(1, 2)[0] - y = pipe.vae.encode([vae_input.to(dtype=pipe.torch_dtype, device=pipe.device)], device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)[0] + with VAE_patch_parallel(): + y = pipe.vae.encode([vae_input.to(dtype=pipe.torch_dtype, device=pipe.device)], device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)[0] y = y.to(dtype=pipe.torch_dtype, device=pipe.device) y = torch.concat([msk, y]) y = y.unsqueeze(0) @@ -654,7 +837,8 @@ class WanVideoUnit_ImageEmbedderFused(PipelineUnit): return {} pipe.load_models_to_device(self.onload_model_names) image = pipe.preprocess_image(input_image.resize((width, height))).transpose(0, 1) - z = pipe.vae.encode([image], device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride) + with VAE_patch_parallel(): + z = pipe.vae.encode([image], device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride) latents[:, :, 0: 1] = z return {"latents": latents, "fuse_vae_embedding_in_latents": True, "first_frame_latents": z} @@ -672,7 +856,8 @@ class WanVideoUnit_FunControl(PipelineUnit): return {} pipe.load_models_to_device(self.onload_model_names) control_video = pipe.preprocess_video(control_video) - control_latents = pipe.vae.encode(control_video, device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(dtype=pipe.torch_dtype, device=pipe.device) + with VAE_patch_parallel(): + control_latents = pipe.vae.encode(control_video, device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(dtype=pipe.torch_dtype, device=pipe.device) control_latents = control_latents.to(dtype=pipe.torch_dtype, device=pipe.device) if clip_feature is None or y is None: clip_feature = torch.zeros((1, 257, 1280), dtype=pipe.torch_dtype, device=pipe.device) @@ -697,7 +882,8 @@ class WanVideoUnit_FunReference(PipelineUnit): pipe.load_models_to_device(["vae"]) reference_image = reference_image.resize((width, height)) reference_latents = pipe.preprocess_video([reference_image]) - reference_latents = pipe.vae.encode(reference_latents, device=pipe.device) + with VAE_patch_parallel(): + reference_latents = pipe.vae.encode(reference_latents, device=pipe.device) clip_feature = pipe.preprocess_image(reference_image) clip_feature = pipe.image_encoder.encode_image([clip_feature]) return {"reference_latents": reference_latents, "clip_feature": clip_feature} @@ -732,7 +918,8 @@ class WanVideoUnit_FunCameraControl(PipelineUnit): input_image = input_image.resize((width, height)) input_latents = pipe.preprocess_video([input_image]) 
pipe.load_models_to_device(self.onload_model_names) - input_latents = pipe.vae.encode(input_latents, device=pipe.device) + with VAE_patch_parallel(): + input_latents = pipe.vae.encode(input_latents, device=pipe.device) y = torch.zeros_like(latents).to(pipe.device) y[:, :, :1] = input_latents y = y.to(dtype=pipe.torch_dtype, device=pipe.device) @@ -780,8 +967,9 @@ class WanVideoUnit_VACE(PipelineUnit): inactive = vace_video * (1 - vace_video_mask) + 0 * vace_video_mask reactive = vace_video * vace_video_mask + 0 * (1 - vace_video_mask) - inactive = pipe.vae.encode(inactive, device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(dtype=pipe.torch_dtype, device=pipe.device) - reactive = pipe.vae.encode(reactive, device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(dtype=pipe.torch_dtype, device=pipe.device) + with VAE_patch_parallel(): + inactive = pipe.vae.encode(inactive, device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(dtype=pipe.torch_dtype, device=pipe.device) + reactive = pipe.vae.encode(reactive, device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(dtype=pipe.torch_dtype, device=pipe.device) vace_video_latents = torch.concat((inactive, reactive), dim=1) vace_mask_latents = rearrange(vace_video_mask[0,0], "T (H P) (W Q) -> 1 (P Q) T H W", P=8, Q=8) @@ -791,7 +979,8 @@ class WanVideoUnit_VACE(PipelineUnit): pass else: vace_reference_image = pipe.preprocess_video([vace_reference_image]) - vace_reference_latents = pipe.vae.encode(vace_reference_image, device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(dtype=pipe.torch_dtype, device=pipe.device) + with VAE_patch_parallel(): + vace_reference_latents = pipe.vae.encode(vace_reference_image, device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(dtype=pipe.torch_dtype, device=pipe.device) vace_reference_latents = torch.concat((vace_reference_latents, torch.zeros_like(vace_reference_latents)), dim=1) vace_video_latents = torch.concat((vace_reference_latents, vace_video_latents), dim=2) vace_mask_latents = torch.concat((torch.zeros_like(vace_mask_latents[:, :, :1]), vace_mask_latents), dim=2) @@ -1010,9 +1199,16 @@ def model_fn_wan_video( if use_unified_sequence_parallel: import torch.distributed as dist - from xfuser.core.distributed import (get_sequence_parallel_rank, - get_sequence_parallel_world_size, - get_sp_group) + if torch.npu.is_available(): + from diffsynth.npu_utils.distributed.parallel_mgr import ( + get_sequence_parallel_rank, + get_sequence_parallel_world_size, + get_sp_group + ) + else: + from xfuser.core.distributed import (get_sequence_parallel_rank, + get_sequence_parallel_world_size, + get_sp_group) # Timestep if dit.seperated_timestep and fuse_vae_embedding_in_latents: @@ -1021,6 +1217,10 @@ def model_fn_wan_video( torch.ones((latents.shape[2] - 1, latents.shape[3] * latents.shape[4] // 4), dtype=latents.dtype, device=latents.device) * timestep ]).flatten() t = dit.time_embedding(sinusoidal_embedding_1d(dit.freq_dim, timestep).unsqueeze(0)) + if use_unified_sequence_parallel: + t_chunks = torch.chunk(t, get_sequence_parallel_world_size(), dim=1) + t_chunks = [torch.nn.functional.pad(chunk, (0, 0, 0, t_chunks[0].shape[1]-chunk.shape[1]), value=0) for chunk in t_chunks] + t = t_chunks[get_sequence_parallel_rank()] t_mod = dit.time_projection(t).unflatten(2, (6, dit.dim)) else: t = dit.time_embedding(sinusoidal_embedding_1d(dit.freq_dim, timestep)) @@ 
-1056,28 +1256,68 @@ def model_fn_wan_video( x = torch.concat([reference_latents, x], dim=1) f += 1 - freqs = torch.cat([ - dit.freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1), - dit.freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1), - dit.freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1) - ], dim=-1).reshape(f * h * w, 1, -1).to(x.device) + initialize_freqs_list = dit.freqs_list is None + + if initialize_freqs_list: + freqs = torch.cat([ + dit.freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1), + dit.freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1), + dit.freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1) + ], dim=-1).reshape(f * h * w, 1, -1).to(x.device) + dit.freqs_list = freqs + else: + freqs = dit.freqs_list # TeaCache if tea_cache is not None: tea_cache_update = tea_cache.check(dit, x, t_mod) else: tea_cache_update = False - - if vace_context is not None: - vace_hints = vace(x, vace_context, context, t_mod, freqs) + - # blocks - if use_unified_sequence_parallel: - if dist.is_initialized() and dist.get_world_size() > 1: + if torch.npu.is_available(): + # NPU: apply the sequence-parallel split before VACE + # blocks + if use_unified_sequence_parallel and dist.is_initialized() and dist.get_world_size() > 1: + sp_size = get_sequence_parallel_world_size() + sp_rank = get_sequence_parallel_rank() + chunks = torch.chunk(x, get_sequence_parallel_world_size(), dim=1) pad_shape = chunks[0].shape[1] - chunks[-1].shape[1] chunks = [torch.nn.functional.pad(chunk, (0, 0, 0, chunks[0].shape[1]-chunk.shape[1]), value=0) for chunk in chunks] - x = chunks[get_sequence_parallel_rank()] + x = chunks[sp_rank] + + if NPU_ROPE_AVAILABLE and initialize_freqs_list: + from diffsynth.distributed.xdit_context_parallel import pad_freqs + s = x.shape[1] + freqs_i = pad_freqs(freqs, s * sp_size) + s_per_rank = s + freqs_i_rank = freqs_i[(sp_rank * s_per_rank):((sp_rank + 1) * s_per_rank), :, :] + cos, sin = torch.chunk(torch.view_as_real(freqs_i_rank.to(torch.complex64)), 2, dim=-1) + cos = cos.unsqueeze(0).expand(-1, -1, -1, -1, 2).flatten(-2) + sin = sin.unsqueeze(0).expand(-1, -1, -1, -1, 2).flatten(-2) + freqs = (cos, sin) + dit.freqs_list = freqs + + if vace_context is not None: + vace_hints = vace(x, + vace_context=vace_context, + context=context, + t_mod=t_mod, + freqs=freqs) + else: + freqs = dit.freqs_list + + if vace_context is not None: + vace_hints = vace(x, vace_context, context, t_mod, freqs) + + # blocks + if use_unified_sequence_parallel: + if dist.is_initialized() and dist.get_world_size() > 1: + chunks = torch.chunk(x, get_sequence_parallel_world_size(), dim=1) + pad_shape = chunks[0].shape[1] - chunks[-1].shape[1] + chunks = [torch.nn.functional.pad(chunk, (0, 0, 0, chunks[0].shape[1]-chunk.shape[1]), value=0) for chunk in chunks] + x = chunks[get_sequence_parallel_rank()] if tea_cache_update: x = tea_cache.update(x) else: @@ -1104,7 +1344,7 @@ def model_fn_wan_video( x = block(x, context, t_mod, freqs) if vace_context is not None and block_id in vace.vace_layers_mapping: current_vace_hint = vace_hints[vace.vace_layers_mapping[block_id]] - if use_unified_sequence_parallel and dist.is_initialized() and dist.get_world_size() > 1: + if use_unified_sequence_parallel and dist.is_initialized() and dist.get_world_size() > 1 and not torch.npu.is_available(): current_vace_hint = torch.chunk(current_vace_hint, get_sequence_parallel_world_size(), dim=1)[get_sequence_parallel_rank()] current_vace_hint = torch.nn.functional.pad(current_vace_hint, (0, 0, 0, chunks[0].shape[1] - current_vace_hint.shape[1]), value=0) x = x + current_vace_hint 
* vace_scale @@ -1121,4 +1361,4 @@ def model_fn_wan_video( x = x[:, reference_latents.shape[1]:] f -= 1 x = dit.unpatchify(x, (f, h, w)) - return x + return x \ No newline at end of file diff --git a/diffsynth/utils/__init__.py b/diffsynth/utils/__init__.py index 97f3926411718d008c84fcde792b552915c5687d..13fc50fec1aaf7428c94c265539d91320534ced0 100644 --- a/diffsynth/utils/__init__.py +++ b/diffsynth/utils/__init__.py @@ -102,7 +102,8 @@ class BasePipeline(torch.nn.Module): module.offload() else: model.cpu() - torch.cuda.empty_cache() + if not torch.npu.is_available(): + torch.cuda.empty_cache() # onload models for name, model in self.named_children(): if name in model_names: diff --git a/examples/wanvideo/model_inference/Wan2.1-VACE-14B-Ascend.py b/examples/wanvideo/model_inference/Wan2.1-VACE-14B-Ascend.py new file mode 100644 index 0000000000000000000000000000000000000000..40138f742bf26e627c8172e57b3d557c0a0e7d49 --- /dev/null +++ b/examples/wanvideo/model_inference/Wan2.1-VACE-14B-Ascend.py @@ -0,0 +1,196 @@ +import os +import time +from datetime import datetime +import torch +import torch_npu +from torch_npu.contrib import transfer_to_npu +torch_npu.npu.set_compile_mode(jit_compile=False) +torch.npu.config.allow_internal_format=False + +from PIL import Image +from diffsynth import save_video, VideoData +from diffsynth.pipelines.wan_video_new import WanVideoPipeline, ModelConfig +from modelscope import dataset_snapshot_download + +import torch.distributed as dist + +model_path = "./models/" +wan2p1_vace_14b_path = "./models/Wan-AI/Wan2.1-VACE-14B" +# HEIGHT=480 +# WIDTH=832 +HEIGHT=432 +WIDTH=768 + +pipe = WanVideoPipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="npu", + model_configs=[ + ModelConfig( + model_id="Wan-AI/Wan2.1-VACE-14B", + origin_file_pattern="diffusion_pytorch_model*.safetensors", + local_model_path=model_path, + skip_download=True, + ), + ModelConfig( + path=[os.path.join(wan2p1_vace_14b_path, "models_t5_umt5-xxl-enc-bf16.pth")], + model_id="Wan-AI/Wan2.1-VACE-14B", + origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", + skip_download=True + ), + ModelConfig( + path=[os.path.join(wan2p1_vace_14b_path, "Wan2.1_VAE.pth")], + model_id="Wan-AI/Wan2.1-VACE-14B", + origin_file_pattern="Wan2.1_VAE.pth", + skip_download=True + ), + ], + use_usp=True, + vae_parallel=True +) + +if dist.get_rank() == 0: + dataset_snapshot_download( + dataset_id="DiffSynth-Studio/examples_in_diffsynth", + local_dir="./", + allow_file_pattern=["data/examples/wan/depth_video.mp4", "data/examples/wan/cat_fightning.jpg"] + ) + +torch.npu.synchronize() + +CONTROL_VIDEO = VideoData("data/examples/wan/depth_video.mp4", height=HEIGHT, width=WIDTH) +VACE_IMAGE = Image.open("data/examples/wan/cat_fightning.jpg").resize((WIDTH, HEIGHT)) +PROMPT = "两只可爱的橘猫戴上拳击手套,站在一个拳击台上搏斗。" +NEGATIVE_PROMPT = "色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走" + +def generate_iv2v(prompt, negative_prompt, vace_video, vace_reference_image, height, width): + if dist.get_rank() == 0: + print("========== [WARM UP] REFERENCE IMAGE + REFERENCE VIDEO -> VIDEO TEST ============") + + video = pipe( + prompt=prompt, + negative_prompt=negative_prompt, + vace_video=vace_video, + seed=0, tiled=False, + vace_reference_image=vace_reference_image, + num_inference_steps=2, + num_frames=81, + sigma_shift=16.0, + cfg_merge=True, + height=height, + width=width + ) + + if dist.get_rank() == 0: + print("========== [GENERATE] REFERENCE 
IMAGE + REFERENCE VIDEO -> VIDEO TEST ============") + + e2e_time = time.time() + video = pipe( + prompt=prompt, + negative_prompt=negative_prompt, + vace_video=vace_video, + seed=0, tiled=False, + vace_reference_image=vace_reference_image, + num_inference_steps=20, + num_frames=81, + sigma_shift=16.0, + cfg_merge=True, + height=HEIGHT, + width=WIDTH + ) + + if dist.get_rank() == 0: + save_time = time.time() + timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") + save_video(video, f"video1_14b_{timestamp}.mp4", fps=15, quality=5) + save_time = time.time() - save_time + print(f"MP4 Save time: {save_time}") + print(f"E2E time: {time.time() - e2e_time}") + +def generate_v2v(prompt, negative_prompt, vace_video, height, width): + if dist.get_rank() == 0: + print("========== [WARM UP] REFERENCE VIDEO -> VIDEO TEST ============") + video = pipe( + prompt=prompt, + negative_prompt=negative_prompt, + vace_video=vace_video, + seed=0, tiled=False, + num_inference_steps=2, + num_frames=81, + sigma_shift=16.0, + cfg_merge=True, + height=height, + width=width + ) + + if dist.get_rank() == 0: + print("========== [GENERATE] REFERENCE VIDEO -> VIDEO TEST ============") + + e2e_time = time.time() + video = pipe( + prompt=prompt, + negative_prompt=negative_prompt, + vace_video=vace_video, + seed=0, tiled=False, + num_inference_steps=20, + num_frames=81, + sigma_shift=16.0, + cfg_merge=True, + height=height, + width=width + ) + + if dist.get_rank() == 0: + save_time = time.time() + timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") + save_video(video, f"video1_14b_{timestamp}.mp4", fps=15, quality=5) + save_time = time.time() - save_time + print(f"MP4 Save time: {save_time}") + print(f"E2E time: {time.time() - e2e_time}") + + +def generate_i2v(prompt, negative_prompt, vace_reference_image, height, width): + if dist.get_rank() == 0: + print("========== [WARM UP] REFERENCE IMAGE -> VIDEO TEST ============") + + video = pipe( + prompt=prompt, + negative_prompt=negative_prompt, + seed=0, tiled=False, + vace_reference_image=vace_reference_image, + num_inference_steps=2, + num_frames=61, + sigma_shift=16.0, + cfg_merge=True, + height=height, + width=width + ) + + if dist.get_rank() == 0: + print("========== [GENERATE] REFERENCE IMAGE -> VIDEO TEST ============") + + e2e_time = time.time() + video = pipe( + prompt=prompt, + negative_prompt=negative_prompt, + seed=0, tiled=False, + vace_reference_image=vace_reference_image, + num_inference_steps=20, + num_frames=61, + sigma_shift=16.0, + cfg_merge=True, + height=height, + width=width + ) + + if dist.get_rank() == 0: + save_time = time.time() + timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") + save_video(video, f"video1_14b_{timestamp}.mp4", fps=15, quality=5) + save_time = time.time() - save_time + print(f"MP4 Save time: {save_time}") + print(f"E2E time: {time.time() - e2e_time}") + +if __name__ == "__main__": + generate_iv2v(PROMPT, NEGATIVE_PROMPT, CONTROL_VIDEO, VACE_IMAGE, HEIGHT, WIDTH) + generate_v2v(PROMPT, NEGATIVE_PROMPT, CONTROL_VIDEO, HEIGHT, WIDTH) + generate_i2v(PROMPT, NEGATIVE_PROMPT, VACE_IMAGE, HEIGHT, WIDTH) \ No newline at end of file diff --git a/examples/wanvideo/model_inference/Wan2.2-I2V-A14B-Ascend.py b/examples/wanvideo/model_inference/Wan2.2-I2V-A14B-Ascend.py new file mode 100644 index 0000000000000000000000000000000000000000..a09bc88f0bbd18ee7d2f68bc64319442e9f2c141 --- /dev/null +++ b/examples/wanvideo/model_inference/Wan2.2-I2V-A14B-Ascend.py @@ -0,0 +1,89 @@ +import os +import time +from datetime import datetime +import torch 
+import torch_npu +from torch_npu.contrib import transfer_to_npu +torch_npu.npu.set_compile_mode(jit_compile=False) +torch.npu.config.allow_internal_format=False + +from PIL import Image +from diffsynth import save_video +from diffsynth.pipelines.wan_video_new import WanVideoPipeline, ModelConfig +from modelscope import dataset_snapshot_download + +model_path = './models/' + +pipe = WanVideoPipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device=f"npu:{os.getenv('RANK', 0)}", + model_configs=[ + ModelConfig( + model_id="Wan-AI/Wan2.2-I2V-A14B", + origin_file_pattern="high_noise_model/diffusion_pytorch_model*.safetensors", + local_model_path=model_path, + skip_download=True, + offload_device="cpu" + ), + ModelConfig( + model_id="Wan-AI/Wan2.2-I2V-A14B", + origin_file_pattern="low_noise_model/diffusion_pytorch_model*.safetensors", + local_model_path=model_path, + skip_download=True, + offload_device="cpu" + ), + ModelConfig( + model_id="Wan-AI/Wan2.2-I2V-A14B", + origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", + path=[os.path.join(model_path, "Wan-AI/Wan2.2-I2V-A14B/models_t5_umt5-xxl-enc-bf16.pth")], + skip_download=True, + # offload_device="npu" + ), + ModelConfig( + model_id="Wan-AI/Wan2.2-I2V-A14B", + origin_file_pattern="Wan2.1_VAE.pth", + path=[os.path.join(model_path, "Wan-AI/Wan2.2-I2V-A14B/Wan2.1_VAE.pth")], + skip_download=True, + # offload_device="cpu" + ), + ], + use_usp=True, + vae_parallel=True, +) +pipe.enable_vram_management(num_persistent_param_in_dit=50*10**9) + +if int(os.getenv("RANK", 0)) == 0: + dataset_snapshot_download( + dataset_id="DiffSynth-Studio/examples_in_diffsynth", + local_dir="./", + allow_file_pattern=["data/examples/wan/cat_fightning.jpg"] + ) +input_image = Image.open("data/examples/wan/cat_fightning.jpg").resize((832, 480)) + +video = pipe( + prompt="Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage.", + negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走", + seed=0, tiled=False, + num_inference_steps=10, + num_frames=81, + input_image=input_image, +) + +start_time = time.time() +video = pipe( + prompt="Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage.", + negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走", + seed=0, tiled=False, + num_inference_steps=40, + num_frames=81, + input_image=input_image, +) + +if int(os.getenv("RANK", 0)) == 0: + save_time = time.time() + timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") + save_video(video, f"wan2_2_video1_14b_{timestamp}.mp4", fps=15, quality=5) + save_time = time.time() - save_time + print(f"MP4 Save time: {save_time}") + print(f"E2E time: {time.time() - start_time}") + diff --git a/examples/wanvideo/model_inference/Wan2.2-TI2V-5B-Ascend.py b/examples/wanvideo/model_inference/Wan2.2-TI2V-5B-Ascend.py new file mode 100644 index 0000000000000000000000000000000000000000..d918ceb8767984636f9a180dea0849a7b475041e --- /dev/null +++ b/examples/wanvideo/model_inference/Wan2.2-TI2V-5B-Ascend.py @@ -0,0 +1,83 @@ +import os +import time +from datetime import datetime +import torch +import torch_npu +from torch_npu.contrib import transfer_to_npu +torch_npu.npu.set_compile_mode(jit_compile=False) +torch.npu.config.allow_internal_format=False + +from PIL import Image +from diffsynth import 
save_video +from diffsynth.pipelines.wan_video_new import WanVideoPipeline, ModelConfig +from modelscope import dataset_snapshot_download + +model_path = './models/' + +pipe = WanVideoPipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="npu", + model_configs=[ + ModelConfig( + model_id="Wan-AI/Wan2.2-TI2V-5B", + origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", + path=[os.path.join(model_path, "Wan-AI/Wan2.2-TI2V-5B/models_t5_umt5-xxl-enc-bf16.pth")], + skip_download=True, + # offload_device="cpu" + ), + ModelConfig( + model_id="Wan-AI/Wan2.2-TI2V-5B", + origin_file_pattern="diffusion_pytorch_model*.safetensors", + local_model_path=model_path, + skip_download=True, + # offload_device="cpu" + ), + ModelConfig( + model_id="Wan-AI/Wan2.2-TI2V-5B", + origin_file_pattern="Wan2.2_VAE.pth", + path=[os.path.join(model_path, "Wan-AI/Wan2.2-TI2V-5B/Wan2.2_VAE.pth")], + skip_download=True, + # offload_device="cpu" + ), + ], + use_usp=True, + vae_parallel=True, +) + +# Image-to-video +if int(os.getenv("RANK", 0)) == 0: + dataset_snapshot_download( + dataset_id="DiffSynth-Studio/examples_in_diffsynth", + local_dir="./", + allow_file_pattern=["data/examples/wan/cat_fightning.jpg"] + ) +input_image = Image.open("data/examples/wan/cat_fightning.jpg") + +torch.npu.synchronize() + +video = pipe( + prompt="两只可爱的橘猫戴上拳击手套,站在一个拳击台上搏斗。", + negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走", + seed=42, tiled=False, + height=704, width=1280, + input_image=input_image, + num_frames=121, + num_inference_steps=2, +) +start_time = time.time() +video = pipe( + prompt="两只可爱的橘猫戴上拳击手套,站在一个拳击台上搏斗。", + negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走", + seed=42, tiled=False, + height=704, width=1280, + input_image=input_image, + num_frames=121, + num_inference_steps=40, +) +if int(os.getenv("RANK", 0)) == 0: + save_time = time.time() + timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") + save_video(video, f"wan2_2_video1_5b_{timestamp}.mp4", fps=15, quality=5) + save_time = time.time() - save_time + print(f"MP4 Save time: {save_time}") + print(f"E2E time: {time.time() - start_time}") diff --git a/examples/wanvideo/model_inference/start_test.sh b/examples/wanvideo/model_inference/start_test.sh new file mode 100644 index 0000000000000000000000000000000000000000..6ea3b4332bc49dbf309a55adace3fabb3de03508 --- /dev/null +++ b/examples/wanvideo/model_inference/start_test.sh @@ -0,0 +1,20 @@ +sysctl -w net.ipv4.ip_local_reserved_ports=50000-50015 +if [ "$(uname -m)" = "aarch64" ]; then + export CPLUS_INCLUDE_PATH=/usr/include/c++/12/:/usr/include/c++/12/aarch64-openEuler-linux/:$CPLUS_INCLUDE_PATH + export LD_PRELOAD=/usr/lib64/libjemalloc.so.2:$LD_PRELOAD +fi +export ALGO=1 +export PYTORCH_NPU_ALLOC_CONF='expandable_segments:True' +export TASK_QUEUE_ENABLE=2 + +export CPU_AFFINITY_CONF=1 +export TOKENIZERS_PARALLELISM=false +export ASCEND_LAUNCH_BLOCKING=0 +export PROFILING_ENABLE=0 +export PROFILING_DIR=./prof_wan2.1_vace_14b_ascend_$(date +%Y%m%d-%H%M%S) +torchrun --standalone --nproc_per_node=8 Wan2.1-VACE-14B-Ascend.py +# torchrun --standalone --nproc_per_node=8 Wan2.2-I2V-A14B-Ascend.py + +if [ $PROFILING_ENABLE -eq 1 ]; then + tar -czvf ${PROFILING_DIR}.tar.gz ${PROFILING_DIR} +fi \ No newline at end of file diff --git a/requirements.txt b/requirements.txt index 
889b7fa7fa8741d80452bc202ffffe9a55f7ba67..dbcdc638d9a038de6ca4f3bc28686d429b8486f1 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,7 +1,7 @@ -torch>=2.0.0 -torchvision -cupy-cuda12x -transformers +torch==2.1.0 +torchvision==0.16.0 +# cupy-cuda12x +transformers==4.51.0 controlnet-aux==0.0.7 imageio imageio[ffmpeg] @@ -14,3 +14,13 @@ ftfy pynvml pandas accelerate +gradio>=5.0.0 +numpy>=1.23.5,<2 +yunchang==0.6.0 +opencv-python-headless +strenum +easydict +tqdm +dashscope +imageio-ffmpeg +datasets \ No newline at end of file
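
For reference, a minimal standalone sketch of the env-gated torch_npu profiling wrapper that the WanVideoPipeline denoising loop now uses. The profiler arguments are copied from the patch; the matmul workload, the 20-iteration loop, and the PROFILING_DIR default below are only stand-ins for the real denoising steps:

import os
import torch
import torch_npu

profiling_dir = os.getenv("PROFILING_DIR", "./prof")
experimental_config = torch_npu.profiler._ExperimentalConfig(
    export_type=torch_npu.profiler.ExportType.Text,
    aic_metrics=torch_npu.profiler.AiCMetrics.PipeUtilization,
    profiler_level=torch_npu.profiler.ProfilerLevel.Level1,
    l2_cache=False,
    data_simplification=False,
)
with torch_npu.profiler.profile(
    activities=[torch_npu.profiler.ProfilerActivity.CPU,
                torch_npu.profiler.ProfilerActivity.NPU],
    with_stack=True,
    record_shapes=True,
    profile_memory=True,
    # skip 10 steps, wait 2, warm up 2, then record 1 step
    schedule=torch_npu.profiler.schedule(wait=2, warmup=2, active=1, repeat=1, skip_first=10),
    experimental_config=experimental_config,
    on_trace_ready=torch_npu.profiler.tensorboard_trace_handler(profiling_dir),
) as prof:
    x = torch.randn(1024, 1024, device="npu", dtype=torch.bfloat16)
    for _ in range(20):   # stand-in for the denoising timesteps
        x = x @ x         # dummy workload
        prof.step()       # advance the profiler schedule once per step

With this schedule a trace is only written after skip_first + wait + warmup + active = 15 profiler steps, so PROFILING_ENABLE=1 needs a run with at least that many denoising iterations; start_test.sh then tars up PROFILING_DIR.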
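The freqs_list change in model_fn_wan_video amounts to memoizing the concatenated 3D RoPE table on the DiT across denoising steps and dropping it once generation finishes (the pipeline resets dit.freqs_list to None after decoding). A minimal sketch of that caching pattern, with a SimpleNamespace standing in for the real DiT and illustrative table sizes:

import torch
from types import SimpleNamespace

def get_freqs(dit, f, h, w, device):
    # Build the (f*h*w, 1, d) RoPE table once, then reuse it for every timestep.
    if dit.freqs_list is None:
        freqs = torch.cat([
            dit.freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1),
            dit.freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1),
            dit.freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1),
        ], dim=-1).reshape(f * h * w, 1, -1).to(device)
        dit.freqs_list = freqs
    return dit.freqs_list

# Stand-in DiT: three per-axis tables; the 16/24/24 split is illustrative only.
dit = SimpleNamespace(freqs=[torch.randn(32, 16), torch.randn(32, 24), torch.randn(32, 24)],
                      freqs_list=None)
freqs = get_freqs(dit, f=4, h=8, w=8, device="cpu")   # built on the first step
freqs = get_freqs(dit, f=4, h=8, w=8, device="cpu")   # reused on later steps
dit.freqs_list = None                                 # freed after generation, as the pipeline does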
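The timing prints added around encoding, the DiT loop, VAE decoding, and MP4 saving all follow the same rank-gated pattern; a small sketch, assuming RANK is provided by torchrun and defaults to 0 for single-process runs:

import os
import time

def rank0_print(msg):
    # Only the first rank logs, so a multi-NPU run does not print one copy per process.
    if int(os.getenv("RANK", 0)) == 0:
        print(msg)

start = time.time()
# ... stage under measurement (encode / denoise / decode / save) ...
rank0_print(f"Stage time: {time.time() - start}")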