diff --git a/examples/stable_diffusion_3_boost/infer/attention_processor_cache.py b/examples/stable_diffusion_3_boost/infer/attention_processor_cache.py new file mode 100644 index 0000000000000000000000000000000000000000..7cec200f0115ab6cf2aa31ab490c765cc9ccd690 --- /dev/null +++ b/examples/stable_diffusion_3_boost/infer/attention_processor_cache.py @@ -0,0 +1,137 @@ +from typing import Optional +import math +import mindspore as ms +from mindspore import ops + +from mindone.diffusers.models.attention_processor import Attention +from mindone.diffusers.models.attention_processor import JointAttnProcessor +from mindone.diffusers.models.attention import _chunked_feed_forward + + +def joint_transformerblock_construct(self, hidden_states: ms.Tensor, encoder_hidden_states: ms.Tensor, temb: ms.Tensor): + norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb) + norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = (None,) * 5 + if self.context_pre_only: + norm_encoder_hidden_states = self.norm1_context(encoder_hidden_states, temb) + else: + norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.norm1_context( + encoder_hidden_states, emb=temb + ) + # Attention. + attn_output, context_attn_output = self.attn( + hidden_states=norm_hidden_states, encoder_hidden_states=norm_encoder_hidden_states + ) + # Process attention outputs for the `hidden_states`. + attn_output = gate_msa.unsqueeze(1) * attn_output + hidden_states = hidden_states + attn_output + norm_hidden_states = self.norm2(hidden_states) + norm_hidden_states = norm_hidden_states * (1 + scale_mlp.expand_dims(1)) + shift_mlp.expand_dims(1) + if self._chunk_size is not None: + # "feed_forward_chunk_size" can be used to save memory + ff_output = _chunked_feed_forward(self.ff, norm_hidden_states, self._chunk_dim, self._chunk_size) + else: + ff_output = self.ff(norm_hidden_states) + ff_output = gate_mlp.unsqueeze(1) * ff_output + hidden_states = hidden_states + ff_output + # Process attention outputs for the `encoder_hidden_states`. 
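+ # `context_pre_only` is True only for the final joint block, where the text (context) stream is not updated further.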
+ if self.context_pre_only: + encoder_hidden_states = None + else: + context_attn_output = c_gate_msa.unsqueeze(1) * context_attn_output + encoder_hidden_states = encoder_hidden_states + context_attn_output + norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states) + norm_encoder_hidden_states = norm_encoder_hidden_states * (1 + c_scale_mlp.expand_dims(1)) + c_shift_mlp.expand_dims(1) + if self._chunk_size is not None: + # "feed_forward_chunk_size" can be used to save memory + context_ff_output = _chunked_feed_forward( + self.ff_context, norm_encoder_hidden_states, self._chunk_dim, self._chunk_size + ) + else: + context_ff_output = self.ff_context(norm_encoder_hidden_states) + encoder_hidden_states = encoder_hidden_states + c_gate_mlp.unsqueeze(1) * context_ff_output + return encoder_hidden_states, hidden_states + + +def downsample(hidden_states, merge_factor, method='nearest'): + batch_size, _, channel = hidden_states.shape + cur_h = int(math.sqrt(hidden_states.shape[1])) + cur_w = cur_h + new_h, new_w = int(cur_h / merge_factor), int(cur_w / merge_factor) + hidden_states = hidden_states.permute(0, 2, 1).reshape(batch_size, channel, cur_h, cur_w) + merged_hidden_states = ops.interpolate(hidden_states, size=(new_h, new_w), mode=method) + merged_hidden_states = merged_hidden_states.permute(0, 2, 3, 1).reshape(batch_size, -1, channel) + return merged_hidden_states + + +@ms.jit_class +class ToDoJointAttnProcessor(JointAttnProcessor): + """Attention processor used typically in processing the SD3-like self-attention projections.""" + + def __call__( + self, + attn: Attention, + hidden_states: ms.Tensor, + encoder_hidden_states: ms.Tensor = None, + attention_mask: Optional[ms.Tensor] = None, + ) -> ms.Tensor: + residual = hidden_states + + batch_size, channel, height, width = (None,) * 4 + input_ndim = hidden_states.ndim + if input_ndim == 4: + batch_size, channel, height, width = hidden_states.shape + hidden_states = hidden_states.view(batch_size, channel, height * width).swapaxes(1, 2) + context_input_ndim = encoder_hidden_states.ndim + if context_input_ndim == 4: + batch_size, channel, height, width = encoder_hidden_states.shape + encoder_hidden_states = encoder_hidden_states.view(batch_size, channel, height * width).swapaxes(1, 2) + + batch_size = encoder_hidden_states.shape[0] + + # `sample` projections. + query = attn.to_q(hidden_states) + if attn.use_downsample and attn.layer_idx <= 11: + hidden_states = downsample(hidden_states, attn.token_merge_factor, method=attn.token_merge_method) + key = attn.to_k(hidden_states) + value = attn.to_v(hidden_states) + + # `context` projections. + encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states) + encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states) + encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states) + + # attention + query = ops.cat([query, encoder_hidden_states_query_proj], axis=1) + key = ops.cat([key, encoder_hidden_states_key_proj], axis=1) + value = ops.cat([value, encoder_hidden_states_value_proj], axis=1) + + query = attn.head_to_batch_dim(query) + key = attn.head_to_batch_dim(key) + value = attn.head_to_batch_dim(value) + + hidden_states = ops.operations.nn_ops.FlashAttentionScore(1, scale_value=attn.scale)( + query.to(ms.float16), key.to(ms.float16), value.to(ms.float16), None, None, None, attention_mask + )[3].to(query.dtype) + hidden_states = attn.batch_to_head_dim(hidden_states) + hidden_states = hidden_states.to(query.dtype) + + # Split the attention outputs. 
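+ # Only the key/value tokens were downsampled above; the query keeps its original length, so the attention output still has `residual.shape[1]` sample tokens and the split below is unchanged.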
+ hidden_states, encoder_hidden_states = ( + hidden_states[:, : residual.shape[1]], + hidden_states[:, residual.shape[1] :], + ) + + # linear proj + hidden_states = attn.to_out[0](hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + if not attn.context_pre_only: + encoder_hidden_states = attn.to_add_out(encoder_hidden_states) + + if input_ndim == 4: + hidden_states = hidden_states.swapaxes(-1, -2).reshape(batch_size, channel, height, width) + if context_input_ndim == 4: + encoder_hidden_states = encoder_hidden_states.swapaxes(-1, -2).reshape(batch_size, channel, height, width) + + return hidden_states, encoder_hidden_states + diff --git a/examples/stable_diffusion_3_boost/infer/embeddings_replace.py b/examples/stable_diffusion_3_boost/infer/embeddings_replace.py new file mode 100644 index 0000000000000000000000000000000000000000..481ef1694f9cb44bf15ba657050ac1e9f3dbe22e --- /dev/null +++ b/examples/stable_diffusion_3_boost/infer/embeddings_replace.py @@ -0,0 +1,45 @@ +import math + +import mindspore as ms +from mindspore import nn, ops, mint + +def get_timestep_embedding( + timesteps: ms.Tensor, + embedding_dim: int, + flip_sin_to_cos: bool = False, + downscale_freq_shift: float = 1, + scale: float = 1, + max_period: int = 10000, +): + """ + This matches the implementation in Denoising Diffusion Probabilistic Models: Create sinusoidal timestep embeddings. + + :param timesteps: a 1-D Tensor of N indices, one per batch element. + These may be fractional. + :param embedding_dim: the dimension of the output. :param max_period: controls the minimum frequency of the + embeddings. :return: an [N x dim] Tensor of positional embeddings. + """ + assert len(timesteps.shape) == 1, "Timesteps should be a 1d-array" + + half_dim = embedding_dim // 2 + exponent = -ops.log(ms.Tensor(max_period, dtype=ms.float32)) * ops.arange(start=0, end=half_dim, dtype=ms.float32) + exponent = mint.div(exponent, (half_dim - downscale_freq_shift)) + + emb = ops.exp(exponent) + emb = timesteps.expand_dims(1).float() * emb.expand_dims(0) + + # scale embeddings + emb = scale * emb + + # concat sine and cosine embeddings + emb = ops.cat([ops.sin(emb), ops.cos(emb)], axis=-1) + + # flip sine and cosine embeddings + if flip_sin_to_cos: + sin, cos = mint.split(emb, half_dim, dim=1) + emb = ops.cat((cos, sin), axis=-1) + + # zero pad + if embedding_dim % 2 == 1: + emb = ops.pad(emb, (0, 1, 0, 0)) + return emb diff --git a/examples/stable_diffusion_3_boost/infer/modeling_clip_replace.py b/examples/stable_diffusion_3_boost/infer/modeling_clip_replace.py new file mode 100644 index 0000000000000000000000000000000000000000..6a6bad865f81c30f9720c2946e117549ac2bc44d --- /dev/null +++ b/examples/stable_diffusion_3_boost/infer/modeling_clip_replace.py @@ -0,0 +1,56 @@ +from typing import Optional, Tuple + +import mindspore as ms +from mindspore import ops +from mindspore.ops.operations.nn_ops import FlashAttentionScore + + +def clip_attention_construct( + self, + hidden_states: ms.Tensor, + attention_mask: Optional[ms.Tensor] = None, + causal_attention_mask: Optional[ms.Tensor] = None, + output_attentions: Optional[bool] = False, + ) -> Tuple[ms.Tensor, Optional[ms.Tensor], Optional[Tuple[ms.Tensor]]]: + """Input shape: Batch x Time x Channel""" + + bsz, q_len, _ = hidden_states.shape + + # get query proj + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + src_len = key_states.shape[1] + attn_mask = None + # apply the 
causal_attention_mask first + if causal_attention_mask is not None: + if causal_attention_mask.shape != (bsz, 1, q_len, src_len): + raise ValueError( + f"Attention mask should be of size {(bsz, 1, q_len, src_len)}, but is" + f" {causal_attention_mask.shape}" + ) + attn_mask = causal_attention_mask if attention_mask is None else None + + if attention_mask is not None: + if attention_mask.shape != (bsz, 1, q_len, src_len): + raise ValueError( + f"Attention mask should be of size {(bsz, 1, q_len, src_len)}, but is {attention_mask.shape}" + ) + attn_mask = attention_mask if causal_attention_mask is None else causal_attention_mask + attention_mask + + attn_mask = ops.cast(attn_mask, dtype=ms.bool_) if attn_mask is not None else None + + _, _, softmax_out, attn_output = FlashAttentionScore( + self.num_heads, + keep_prob=1.0, + sparse_mode=0, + scale_value=self.scale, + input_layout='BSH' + )(query_states, key_states, value_states, None, None, None, attn_mask, None) + + attn_weights_reshaped = None if not output_attentions else softmax_out + + attn_output = self.out_proj(attn_output) + + return attn_output, attn_weights_reshaped diff --git a/examples/stable_diffusion_3_boost/infer/modeling_t5_replace.py b/examples/stable_diffusion_3_boost/infer/modeling_t5_replace.py new file mode 100644 index 0000000000000000000000000000000000000000..47298e1fe637e54e016318ebfb02d70ea7ed2c32 --- /dev/null +++ b/examples/stable_diffusion_3_boost/infer/modeling_t5_replace.py @@ -0,0 +1,41 @@ +import mindspore as ms +from mindspore import Tensor, ops, mint + + +def t5_layernorm_construct(self, hidden_states): + variance = mint.mean(hidden_states.to(ms.float32).pow(2), dim=-1, keepdim=True) + hidden_states = hidden_states * ops.rsqrt(variance + self.variance_epsilon) + + # convert into half-precision if necessary + if self.weight.dtype in [ms.float16, ms.bfloat16]: + hidden_states = hidden_states.to(self.weight.dtype) + + return self.weight * hidden_states + + +def t5_attention_relative_position_bucket(relative_position, bidirectional=True, num_buckets=32, max_distance=128): + relative_buckets = 0 + if bidirectional: + num_buckets //= 2 + relative_buckets += (relative_position > 0).long() * num_buckets + relative_position = ops.abs(relative_position) + else: + relative_position = -ops.minimum(relative_position, ops.zeros_like(relative_position)) + # now relative_position is in the range [0, inf) + # half of the buckets are for exact increments in positions + max_exact = num_buckets // 2 + is_small = relative_position < max_exact + # The other half of the buckets are for logarithmically bigger bins in positions up to max_distance + relative_position_if_large = ( + max_exact + + (mint.div( + ops.log(mint.div(relative_position.float(), max_exact)) + , ops.log(Tensor(max_distance, ms.float32) / max_exact)) + * (num_buckets - max_exact) + ).long() + ) + relative_position_if_large = ops.minimum( + relative_position_if_large, ops.full_like(relative_position_if_large, num_buckets - 1) + ) + relative_buckets += ops.where(is_small, relative_position, relative_position_if_large) + return relative_buckets diff --git a/examples/stable_diffusion_3_boost/infer/normalization_replace.py b/examples/stable_diffusion_3_boost/infer/normalization_replace.py new file mode 100644 index 0000000000000000000000000000000000000000..91639e5d54e00b3815180b175533ca54fd5da28b --- /dev/null +++ b/examples/stable_diffusion_3_boost/infer/normalization_replace.py @@ -0,0 +1,47 @@ +from typing import Optional, Tuple + +import mindspore as ms +from 
mindspore import ops, mint + +from mindone.diffusers.models.layers_compat import group_norm + +def ada_layernorm_construct(self, x: ms.Tensor, timestep: ms.Tensor) -> ms.Tensor: + # Argument 'timestep' is a 0-dim tensor, so we unsqueeze it first + # because the input tensor of nn.Dense must have more than 1 dim. + emb = self.linear(self.silu(self.emb(timestep[None]))) + scale, shift = mint.chunk(emb, 2, dim=1) + x = self.norm(x) * (1 + scale.expand_dims(1)) + shift.expand_dims(1) + return x + +def ada_layernormzero_construct( + self, + x: ms.Tensor, + timestep: Optional[ms.Tensor] = None, + class_labels: Optional[ms.Tensor] = None, + hidden_dtype=None, + emb: Optional[ms.Tensor] = None, + ) -> Tuple[ms.Tensor, ms.Tensor, ms.Tensor, ms.Tensor, ms.Tensor]: + if self.emb is not None: + emb = self.emb(timestep, class_labels, hidden_dtype=hidden_dtype) + emb = self.linear(self.silu(emb)) + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = mint.chunk(emb, 6, dim=1) + x = self.norm(x) * (1 + scale_msa.expand_dims(1)) + shift_msa.expand_dims(1) + return x, gate_msa, shift_mlp, scale_mlp, gate_mlp + +def ada_groupnorm_construct(self, x: ms.Tensor, emb: ms.Tensor) -> ms.Tensor: + if self.act: + emb = self.act(emb) + emb = self.linear(emb) + emb = emb.expand_dims(2).expand_dims(2) + + scale, shift = mint.chunk(emb, 2, dim=1) + x = group_norm(x, self.num_groups, None, None, self.eps) + x = x * (1 + scale) + shift + return x + +def ada_layernorm_continuous_construct(self, x: ms.Tensor, conditioning_embedding: ms.Tensor) -> ms.Tensor: + # convert back to the original dtype in case `conditioning_embedding` is upcast to float32 (needed for HunyuanDiT) + emb = self.linear(self.silu(conditioning_embedding).to(x.dtype)) + scale, shift = ops.chunk(emb, 2, axis=1) + x = self.norm(x) * (1 + scale).expand_dims(1) + shift.expand_dims(1) + return x diff --git a/examples/stable_diffusion_3_boost/infer/transformer_sd3_cache.py b/examples/stable_diffusion_3_boost/infer/transformer_sd3_cache.py new file mode 100644 index 0000000000000000000000000000000000000000..2619b64df1dfe4cce54d88998595005ed9e11bb0 --- /dev/null +++ b/examples/stable_diffusion_3_boost/infer/transformer_sd3_cache.py @@ -0,0 +1,186 @@ +from typing import Any, Dict, List, Tuple, Optional, Union +import mindspore as ms + +from mindone.diffusers.models.transformers.transformer_2d import Transformer2DModelOutput + + +def sd3_transformer2d_construct( + self, + hidden_states: ms.Tensor, + encoder_hidden_states: ms.Tensor = None, + pooled_projections: ms.Tensor = None, + timestep: ms.Tensor = None, + block_controlnet_hidden_states: List = None, + joint_attention_kwargs: Optional[Dict[str, Any]] = None, + return_dict: bool = False, + cache_params: Tuple = None, + if_skip: bool = False, + delta_cache: ms.Tensor = None, + delta_cache_hidden: ms.Tensor = None, + use_cache: bool = False, + ) -> Union[ms.Tensor, Transformer2DModelOutput, Tuple]: + """ + The [`SD3Transformer2DModel`] forward method. + Args: + hidden_states (`ms.Tensor` of shape `(batch size, channel, height, width)`): + Input `hidden_states`. + encoder_hidden_states (`ms.Tensor` of shape `(batch size, sequence_len, embed_dims)`): + Conditional embeddings (embeddings computed from the input conditions such as prompts) to use. + pooled_projections (`ms.Tensor` of shape `(batch_size, projection_dim)`): Embeddings projected + from the embeddings of input conditions. + timestep (`ms.Tensor`): + Used to indicate the denoising step.
+ block_controlnet_hidden_states (`list` of `ms.Tensor`, *optional*): + A list of tensors that, if specified, are added to the residuals of the transformer blocks. + joint_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that, if specified, is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + return_dict (`bool`, *optional*, defaults to `False`): + Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain + tuple. + cache_params (`Tuple`, *optional*): + A tuple of cache parameters: (index of the first cached block, cache step stride, number of cached blocks, first denoising step that uses the cache). + if_skip (`bool`, *optional*, defaults to `False`): + Whether to reuse the cached block deltas on this step instead of recomputing the cached block range. + delta_cache (`ms.Tensor`, *optional*): + Cached delta of `hidden_states` produced by the cached block range. + delta_cache_hidden (`ms.Tensor`, *optional*): + Cached delta of `encoder_hidden_states` produced by the cached block range. + use_cache (`bool`, *optional*, defaults to `False`): + Whether the block-cache mechanism is enabled. + Returns: + If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a + `tuple` where the first element is the sample tensor; when `use_cache` is True the tuple also carries the + updated `delta_cache` and `delta_cache_hidden`. + """ + if joint_attention_kwargs is not None and "scale" in joint_attention_kwargs: + # weight the lora layers by setting `lora_scale` for each PEFT layer here + # and remove `lora_scale` from each PEFT layer at the end. + # scale_lora_layers & unscale_lora_layers may contain operations forbidden in graph mode + raise RuntimeError( + f"You are trying to set scaling of lora layer by passing {joint_attention_kwargs['scale']=}. " + f"However it's not allowed in on-the-fly model forwarding. " + f"Please manually call `scale_lora_layers(model, lora_scale)` before model forwarding and " + f"`unscale_lora_layers(model, lora_scale)` after model forwarding. " + f"For example, it can be done in a pipeline call like `StableDiffusionPipeline.__call__`." + ) + height, width = hidden_states.shape[-2:] + hidden_states = self.pos_embed(hidden_states) # takes care of adding positional embeddings too.
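+ # After patch embedding, `hidden_states` is a (batch_size, num_patches, inner_dim) sequence of patch tokens.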
+ temb = self.time_text_embed(timestep, pooled_projections) + encoder_hidden_states = self.context_embedder(encoder_hidden_states) + ( + (encoder_hidden_states, hidden_states), + delta_cache, + delta_cache_hidden + ) = self.forward_blocks( + hidden_states, + encoder_hidden_states, + block_controlnet_hidden_states, + temb, + use_cache, + if_skip, + cache_params, + delta_cache, + delta_cache_hidden, + ) + hidden_states = self.norm_out(hidden_states, temb) + hidden_states = self.proj_out(hidden_states) + # unpatchify + patch_size = self.config["patch_size"] + height = height // patch_size + width = width // patch_size + hidden_states = hidden_states.reshape( + hidden_states.shape[0], + height, + width, + patch_size, + patch_size, + self.out_channels, + ) + # hidden_states = torch.einsum("nhwpqc->nchpwq", hidden_states) + hidden_states = hidden_states.transpose(0, 5, 1, 3, 2, 4) + output = hidden_states.reshape( + hidden_states.shape[0], self.out_channels, height * patch_size, width * patch_size + ) + if not return_dict: + return (output, delta_cache, delta_cache_hidden) if use_cache else (output,) + return Transformer2DModelOutput(sample=output) + + +def forward_blocks_range( + self, + hidden_states, + encoder_hidden_states, + block_controlnet_hidden_states, + temb, + start_idx, + end_idx, +): + for index_block, block in enumerate(self.transformer_blocks[start_idx:end_idx]): + encoder_hidden_states, hidden_states = block( + hidden_states=hidden_states, + encoder_hidden_states=encoder_hidden_states, + temb=temb, + ) + # controlnet residual + if block_controlnet_hidden_states is not None and block.context_pre_only is False: + interval_control = len(self.transformer_blocks) // len(block_controlnet_hidden_states) + hidden_states = hidden_states + block_controlnet_hidden_states[index_block // interval_control] + + return hidden_states, encoder_hidden_states + + +def forward_blocks( + self, + hidden_states, + encoder_hidden_states, + block_controlnet_hidden_states, + temb, + use_cache, + if_skip, + cache_params, + delta_cache, + delta_cache_hidden, +): + if not use_cache: + hidden_states, encoder_hidden_states = self.forward_blocks_range( + hidden_states, + encoder_hidden_states, + block_controlnet_hidden_states, + temb, + start_idx=0, + end_idx=len(self.transformer_blocks) + ) + else: + # infer [0, cache_start) + hidden_states, encoder_hidden_states = self.forward_blocks_range( + hidden_states, + encoder_hidden_states, + block_controlnet_hidden_states, + temb, + start_idx=0, + end_idx=cache_params[0], + ) + # infer [cache_start, cache_end) + cache_end = cache_params[0] + cache_params[2] + hidden_states_before_cache = hidden_states.copy() + encoder_hidden_states_before_cache = encoder_hidden_states.copy() + if not if_skip: + hidden_states, encoder_hidden_states = self.forward_blocks_range( + hidden_states, + encoder_hidden_states, + block_controlnet_hidden_states, + temb, + start_idx=cache_params[0], + end_idx=cache_end, + ) + delta_cache = hidden_states - hidden_states_before_cache + delta_cache_hidden = encoder_hidden_states - encoder_hidden_states_before_cache + else: + hidden_states = hidden_states_before_cache + delta_cache + encoder_hidden_states = encoder_hidden_states_before_cache + delta_cache_hidden + + # infer [cache_end, len(self.blocks)) + hidden_states, encoder_hidden_states = self.forward_blocks_range( + hidden_states, + encoder_hidden_states, + block_controlnet_hidden_states, + temb, + start_idx=cache_end, + end_idx=len(self.transformer_blocks), + ) + return (encoder_hidden_states, 
hidden_states), delta_cache, delta_cache_hidden + + diff --git a/examples/stable_diffusion_3_boost/pipeline_stable_diffusion_3_boost.py b/examples/stable_diffusion_3_boost/pipeline_stable_diffusion_3_boost.py new file mode 100644 index 0000000000000000000000000000000000000000..75243b0be42056bde8a1525fb921dd841200d7d6 --- /dev/null +++ b/examples/stable_diffusion_3_boost/pipeline_stable_diffusion_3_boost.py @@ -0,0 +1,419 @@ +from typing import Any, Callable, Dict, List, Optional, Union + +import mindone.diffusers +import mindone.transformers +import mindone.transformers.models +import numpy as np +from transformers import CLIPTokenizer, T5TokenizerFast +import mindspore as ms +from mindspore import ops, mint +from mindspore.common.api import _pynative_executor as ms_pyexecutor + +import mindone +from mindone.diffusers.models import SD3Transformer2DModel +from mindone.diffusers.models.autoencoders.autoencoder_kl import AutoencoderKL +from mindone.diffusers.schedulers.scheduling_flow_match_euler_discrete import FlowMatchEulerDiscreteScheduler +from mindone.transformers.models.clip.modeling_clip import CLIPTextModelWithProjection +from mindone.transformers.models.t5.modeling_t5 import T5EncoderModel +from mindone.diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3 import StableDiffusion3Pipeline, retrieve_timesteps +from mindone.diffusers.utils.peft_utils import scale_lora_layers, unscale_lora_layers +from mindone.diffusers.pipelines.stable_diffusion_3.pipeline_output import StableDiffusion3PipelineOutput +from .infer.transformer_sd3_cache import sd3_transformer2d_construct, forward_blocks, forward_blocks_range +from .infer.attention_processor_cache import ToDoJointAttnProcessor, joint_transformerblock_construct +from .infer.normalization_replace import ada_groupnorm_construct, ada_layernorm_construct, ada_layernorm_continuous_construct, ada_layernormzero_construct +from .infer.modeling_t5_replace import t5_layernorm_construct, t5_attention_relative_position_bucket +from .infer.embeddings_replace import get_timestep_embedding +from .infer.modeling_clip_replace import clip_attention_construct + +mindone.diffusers.models.SD3Transformer2DModel.construct=sd3_transformer2d_construct +mindone.diffusers.models.SD3Transformer2DModel.forward_blocks=forward_blocks +mindone.diffusers.models.SD3Transformer2DModel.forward_blocks_range=forward_blocks_range +mindone.diffusers.models.attention.JointTransformerBlock.construct=joint_transformerblock_construct +mindone.diffusers.models.normalization.AdaLayerNorm.construct=ada_layernorm_construct +mindone.diffusers.models.normalization.AdaLayerNormZero.construct=ada_layernormzero_construct +mindone.diffusers.models.normalization.AdaLayerNormContinuous.construct=ada_layernorm_continuous_construct +mindone.diffusers.models.normalization.AdaGroupNorm.construct=ada_groupnorm_construct +mindone.diffusers.models.embeddings.get_timestep_embedding=get_timestep_embedding +mindone.transformers.models.t5.modeling_t5.T5LayerNorm.construct=t5_layernorm_construct +mindone.transformers.models.t5.modeling_t5.T5Attention._relative_position_bucket=t5_attention_relative_position_bucket +mindone.transformers.models.clip.modeling_clip.CLIPAttention.construct=clip_attention_construct + +class StableDiffusion3PipelineBoost(StableDiffusion3Pipeline): + def __init__( + self, + transformer: SD3Transformer2DModel, + scheduler: FlowMatchEulerDiscreteScheduler, + vae: AutoencoderKL, + text_encoder: CLIPTextModelWithProjection, + tokenizer: CLIPTokenizer, + text_encoder_2: 
CLIPTextModelWithProjection, + tokenizer_2: CLIPTokenizer, + text_encoder_3: T5EncoderModel, + tokenizer_3: T5TokenizerFast + ): + super().__init__( + transformer, + scheduler, + vae, + text_encoder, + tokenizer, + text_encoder_2, + tokenizer_2, + text_encoder_3, + tokenizer_3 + ) + + self.tgate = 20 + self.use_cache_and_tgate = False + self.cache_params = (1, 2, 20, 10) + self.token_merge_factor = 1.6 + self.token_merge_method = "bilinear" + self.use_todo = False + self.init_todo_processor = False + + def _enable_boost(self, use_cache_and_tgate: bool = True, use_todo: bool = False): + if not self.init_todo_processor: + self.init_todo_processor = True + self.transformer.set_attn_processor(ToDoJointAttnProcessor()) + for block_idx, transformer_block in enumerate(self.transformer.transformer_blocks): + transformer_block.attn.use_downsample = use_todo + transformer_block.attn.layer_idx = block_idx + transformer_block.attn.token_merge_factor = self.token_merge_factor + transformer_block.attn.token_merge_method = self.token_merge_method + if self.use_todo != use_todo: + for transformer_block in self.transformer.transformer_blocks: + transformer_block.attn.use_downsample = use_todo + self.use_todo = use_todo + self.use_cache_and_tgate = use_cache_and_tgate + + def __call__( + self, + prompt: Union[str, List[str]] = None, + prompt_2: Optional[Union[str, List[str]]] = None, + prompt_3: Optional[Union[str, List[str]]] = None, + height: Optional[int] = None, + width: Optional[int] = None, + num_inference_steps: int = 28, + timesteps: List[int] = None, + guidance_scale: float = 7.0, + negative_prompt: Optional[Union[str, List[str]]] = None, + negative_prompt_2: Optional[Union[str, List[str]]] = None, + negative_prompt_3: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + generator: Optional[Union[np.random.Generator, List[np.random.Generator]]] = None, + latents: Optional[ms.Tensor] = None, + prompt_embeds: Optional[ms.Tensor] = None, + negative_prompt_embeds: Optional[ms.Tensor] = None, + pooled_prompt_embeds: Optional[ms.Tensor] = None, + negative_pooled_prompt_embeds: Optional[ms.Tensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = False, + joint_attention_kwargs: Optional[Dict[str, Any]] = None, + clip_skip: Optional[int] = None, + callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + max_sequence_length: int = 77, + use_cache_and_tgate: bool = False, + use_todo: bool = False, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds` + instead. + prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` + will be used instead. + prompt_3 (`str` or `List[str]`, *optional*): + The prompt or prompts to be sent to `tokenizer_3` and `text_encoder_3`. If not defined, `prompt` + will be used instead. + height (`int`, *optional*, defaults to `self.default_sample_size * self.vae_scale_factor`): + The height in pixels of the generated image. This is set to 1024 by default for the best results.
+ width (`int`, *optional*, defaults to `self.default_sample_size * self.vae_scale_factor`): + The width in pixels of the generated image. This is set to 1024 by default for the best results. + num_inference_steps (`int`, *optional*, defaults to 28): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + timesteps (`List[int]`, *optional*): + Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument + in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is + passed will be used. Must be in descending order. + guidance_scale (`float`, *optional*, defaults to 7.0): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2 of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. A higher guidance scale encourages the model to generate images closely linked to the text `prompt`, + usually at the expense of lower image quality. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + negative_prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and + `text_encoder_2`. If not defined, `negative_prompt` is used instead. + negative_prompt_3 (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation to be sent to `tokenizer_3` and + `text_encoder_3`. If not defined, `negative_prompt` is used instead. + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + generator (`np.random.Generator` or `List[np.random.Generator]`, *optional*): + One or a list of [numpy random Generator(s)](https://numpy.org/doc/stable/reference/random/generator.html) + to make generation deterministic. + latents (`ms.Tensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will be generated by sampling using the supplied random `generator`. + prompt_embeds (`ms.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`ms.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + pooled_prompt_embeds (`ms.Tensor`, *optional*): + Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. + If not provided, pooled text embeddings will be generated from `prompt` input argument. + negative_pooled_prompt_embeds (`ms.Tensor`, *optional*): + Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt` + input argument.
+ output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generated image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `False`): + Whether or not to return a [`~pipelines.stable_diffusion_3.StableDiffusion3PipelineOutput`] instead + of a plain tuple. + joint_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that, if specified, is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + callback_on_step_end (`Callable`, *optional*): + A function that is called at the end of each denoising step during inference. The function is called + with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, + callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by + `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`List`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + max_sequence_length (`int`, *optional*, defaults to 77): Maximum sequence length to use with the `prompt`. + use_cache_and_tgate (`bool`, *optional*, defaults to `False`): + Whether to enable the transformer-block delta cache together with T-GATE: after `self.tgate` steps the + negative (classifier-free guidance) branch is dropped and the cached block deltas are reused on part of + the remaining steps. + use_todo (`bool`, *optional*, defaults to `False`): + Whether to enable ToDo token downsampling of the attention key/value tokens in the early transformer blocks. + + Examples: + + Returns: + [`~pipelines.stable_diffusion_3.StableDiffusion3PipelineOutput`] or `tuple`: + [`~pipelines.stable_diffusion_3.StableDiffusion3PipelineOutput`] if `return_dict` is True, otherwise a + `tuple`. When returning a tuple, the first element is a list with the generated images. + """ + + self._enable_boost(use_cache_and_tgate=use_cache_and_tgate, use_todo=use_todo) + height = height or self.default_sample_size * self.vae_scale_factor + width = width or self.default_sample_size * self.vae_scale_factor + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + prompt_2, + prompt_3, + height, + width, + negative_prompt=negative_prompt, + negative_prompt_2=negative_prompt_2, + negative_prompt_3=negative_prompt_3, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, + callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, + max_sequence_length=max_sequence_length, + ) + + self._guidance_scale = guidance_scale + self._clip_skip = clip_skip + self._joint_attention_kwargs = joint_attention_kwargs + self._interrupt = False + + # 2.
Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + ( + prompt_embeds, + negative_prompt_embeds, + pooled_prompt_embeds, + negative_pooled_prompt_embeds, + ) = self.encode_prompt( + prompt=prompt, + prompt_2=prompt_2, + prompt_3=prompt_3, + negative_prompt=negative_prompt, + negative_prompt_2=negative_prompt_2, + negative_prompt_3=negative_prompt_3, + do_classifier_free_guidance=self.do_classifier_free_guidance, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, + clip_skip=self.clip_skip, + num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, + ) + + if self.use_cache_and_tgate: + prompt_embeds_origin = prompt_embeds.copy() + pooled_prompt_embeds_origin = pooled_prompt_embeds.copy() + if self.do_classifier_free_guidance: + prompt_embeds = ops.cat([negative_prompt_embeds, prompt_embeds], axis=0) + pooled_prompt_embeds = ops.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds], axis=0) + + # 4. Prepare timesteps + timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, timesteps) + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + self._num_timesteps = len(timesteps) + + # 5. Prepare latent variables + num_channels_latents = self.transformer.config.in_channels + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + generator, + latents, + ) + + # we're popping the `scale` instead of getting it because otherwise `scale` will be propagated + # to the transformer and will raise RuntimeError. + lora_scale = self.joint_attention_kwargs.pop("scale", None) if self.joint_attention_kwargs is not None else None + if lora_scale is not None: + # weight the lora layers by setting `lora_scale` for each PEFT layer + scale_lora_layers(self.transformer, lora_scale) + + if self.use_cache_and_tgate: + latents_height, latents_width = latents.shape[-2:] + patchs_num = (latents_height // self.transformer.config.patch_size) * (latents_width // self.transformer.config.patch_size) + delta_cache = ops.zeros([2, patchs_num, 1536], dtype=ms.float16) + delta_cache_hidden = ops.zeros([2, max_sequence_length + 77, 1536], dtype=ms.float16) + + cache_interval = self.cache_params[1] + step_constrast = self.cache_params[3] % 2 + + # 6. 
Denoising loop + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + if not self.use_cache_and_tgate: + # expand the latents if we are doing classifier free guidance + latent_model_input = ops.cat([latents] * 2) if self.do_classifier_free_guidance else latents + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timestep = t.broadcast_to((latent_model_input.shape[0],)) + + ms_pyexecutor.sync() + ms_pyexecutor.set_async_for_graph(False) + noise_pred = self.transformer( + hidden_states=latent_model_input, + timestep=timestep, + encoder_hidden_states=prompt_embeds, + pooled_projections=pooled_prompt_embeds, + joint_attention_kwargs=self.joint_attention_kwargs, + )[0] + ms_pyexecutor.set_async_for_graph(True) + else: + if i < self.tgate: + latent_model_input = ops.cat([latents] * 2) if self.do_classifier_free_guidance else latents + else: + if i == self.tgate: + _, delta_cache = mint.chunk(delta_cache, 2) + _, delta_cache_hidden = mint.chunk(delta_cache_hidden, 2) + latent_model_input = latents + timestep = t.broadcast_to((latent_model_input.shape[0],)) + + ms_pyexecutor.sync() + ms_pyexecutor.set_async_for_graph(False) + if i < (self.cache_params[3] - 1): + noise_pred = self.transformer( + hidden_states=latent_model_input, + timestep=timestep, + encoder_hidden_states=prompt_embeds, + pooled_projections=pooled_prompt_embeds, + joint_attention_kwargs=self.joint_attention_kwargs, + cache_params=self.cache_params, + if_skip=False, + use_cache=False, + delta_cache=delta_cache, + delta_cache_hidden=delta_cache_hidden, + )[0] + else: + noise_pred, delta_cache, delta_cache_hidden = self.transformer( + hidden_states=latent_model_input, + timestep=timestep, + encoder_hidden_states=prompt_embeds if i < self.tgate else prompt_embeds_origin, + pooled_projections=pooled_prompt_embeds if i < self.tgate else pooled_prompt_embeds_origin, + joint_attention_kwargs=self.joint_attention_kwargs, + cache_params=self.cache_params, + if_skip=((i >= self.cache_params[3]) and (i % cache_interval == step_constrast)), + use_cache=True, + delta_cache=delta_cache, + delta_cache_hidden=delta_cache_hidden, + ) + ms_pyexecutor.set_async_for_graph(True) + + # perform guidance + if self.do_classifier_free_guidance and (not self.use_cache_and_tgate or i < self.tgate): + noise_pred_uncond, noise_pred_text = mint.chunk(noise_pred, 2) + noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents_dtype = latents.dtype + latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] + + if latents.dtype != latents_dtype: + latents = latents.to(latents_dtype) + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + negative_pooled_prompt_embeds = callback_outputs.pop( + "negative_pooled_prompt_embeds", negative_pooled_prompt_embeds + ) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if lora_scale is not None: + # remove 
`lora_scale` from each PEFT layer + unscale_lora_layers(self.transformer, lora_scale) + + if output_type == "latent": + image = latents + + else: + latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor + latents = latents.to( + self.vae.dtype + ) # for validation in training where vae and transformer might have different dtype + + image = self.vae.decode(latents, return_dict=False)[0] + image = self.image_processor.postprocess(image, output_type=output_type) + + if not return_dict: + return (image,) + + return StableDiffusion3PipelineOutput(images=image) \ No newline at end of file diff --git a/mindone/diffusers/models/embeddings.py b/mindone/diffusers/models/embeddings.py index bb74c4ed2f2d00aaff8519ae774b75ec6c212944..86f2009e5e8a1f72ff081c8f93fd60ab911d00fd 100644 --- a/mindone/diffusers/models/embeddings.py +++ b/mindone/diffusers/models/embeddings.py @@ -17,7 +17,7 @@ from typing import List, Optional, Tuple, Union import numpy as np import mindspore as ms -from mindspore import nn, ops +from mindspore import nn, ops, mint from .activations import FP32SiLU, get_activation from .attention_processor import Attention @@ -56,7 +56,8 @@ def get_timestep_embedding( # flip sine and cosine embeddings if flip_sin_to_cos: - emb = ops.cat([emb[:, half_dim:], emb[:, :half_dim]], axis=-1) + sin, cos = mint.split(emb, half_dim, dim=1) + emb = ops.cat((cos, sin), axis=-1) # zero pad if embedding_dim % 2 == 1: diff --git a/mindone/diffusers/models/normalization.py b/mindone/diffusers/models/normalization.py index e4e4b98a85461cfec86b94e572ef5455a0a65a92..88225cdbaaa3ee65aa7ce286d3d8ee7d6d1e2a48 100644 --- a/mindone/diffusers/models/normalization.py +++ b/mindone/diffusers/models/normalization.py @@ -290,7 +290,7 @@ class LayerNorm(nn.Cell): self.layer_norm = ops.LayerNorm(-1, -1, epsilon=eps) def construct(self, x: Tensor): - x, _, _ = self.layer_norm(x, self.weight, self.bias) + x, _, _ = self.layer_norm(x, self.weight.to(x.dtype), self.bias.to(x.dtype)) return x diff --git a/mindone/diffusers/models/transformers/t5_film_transformer.py b/mindone/diffusers/models/transformers/t5_film_transformer.py index a4691e437f2987a64dc72b9f12804e0a2f2f55ac..f6ed2f212b25aa6d4a3fce6b27a67b2d6aac8cc9 100644 --- a/mindone/diffusers/models/transformers/t5_film_transformer.py +++ b/mindone/diffusers/models/transformers/t5_film_transformer.py @@ -14,7 +14,7 @@ from typing import Optional, Tuple import mindspore as ms -from mindspore import nn, ops +from mindspore import nn, ops, mint from ...configuration_utils import ConfigMixin, register_to_config from ..attention_processor import Attention @@ -409,7 +409,7 @@ class NewGELUActivation(nn.Cell): def construct(self, input: ms.Tensor) -> ms.Tensor: # Magic number 0.797885 comes from math.sqrt(2.0 / math.pi) as float32 - return 0.5 * input * (1.0 + ops.tanh(0.797885 * (input + 0.044715 * ops.pow(input, 3.0)))) + return mint.nn.functional.gelu(input) class T5FiLMLayer(nn.Cell): diff --git a/mindone/diffusers/schedulers/scheduling_flow_match_euler_discrete.py b/mindone/diffusers/schedulers/scheduling_flow_match_euler_discrete.py index 553878785566d65e3f555f1bf5f92d546dfee59e..ec220c3501e9309f417474a6509fbf6f845d84fa 100644 --- a/mindone/diffusers/schedulers/scheduling_flow_match_euler_discrete.py +++ b/mindone/diffusers/schedulers/scheduling_flow_match_euler_discrete.py @@ -247,13 +247,12 @@ class FlowMatchEulerDiscreteScheduler(SchedulerMixin, ConfigMixin): sigma = self.sigmas[self.step_index] gamma = min(s_churn / (len(self.sigmas) - 1), 
2**0.5 - 1) if s_tmin <= sigma <= s_tmax else 0.0 - - noise = randn_tensor(model_output.shape, dtype=model_output.dtype, generator=generator) - - eps = noise * s_noise + sigma_hat = sigma * (gamma + 1) if gamma > 0: + noise = randn_tensor(model_output.shape, dtype=model_output.dtype, generator=generator) + eps = noise * s_noise sample = sample + eps * (sigma_hat**2 - sigma**2) ** 0.5 # 1. compute predicted original sample (x_0) from sigma-scaled predicted noise