diff --git a/MindIE/MultiModal/Flux.1-DEV/FLUX1dev/layers/attention_processor.py b/MindIE/MultiModal/Flux.1-DEV/FLUX1dev/layers/attention_processor.py
index cad77988f9f5afcfa8cfb871a967531cfca34bac..4fb6370b5dd9ec4c56acebb185cc71198f51a6e0 100644
--- a/MindIE/MultiModal/Flux.1-DEV/FLUX1dev/layers/attention_processor.py
+++ b/MindIE/MultiModal/Flux.1-DEV/FLUX1dev/layers/attention_processor.py
@@ -30,8 +30,10 @@ from ..utils import get_local_rank, get_world_size
 
 logger = logging.get_logger(__name__)
 
-def apply_rotary_emb_mindspeed(x, freqs_cis):
+def apply_rotary_emb_mindiesd(x, freqs_cis):
     cos, sin = freqs_cis
+    cos = cos[None, None]
+    sin = sin[None, None]
     cos, sin = cos.to(x.device), sin.to(x.device)
     return rotary_position_embedding(x, cos, sin, rotated_mode="rotated_interleaved", head_first=False, fused=True)
 
@@ -374,8 +376,8 @@ class FluxSingleAttnProcessor2_0:
 
         # Apply RoPE if needed
         if image_rotary_emb is not None:
-            query = apply_rotary_emb_mindspeed(query, image_rotary_emb)
-            key = apply_rotary_emb_mindspeed(key, image_rotary_emb)
+            query = apply_rotary_emb_mindiesd(query, image_rotary_emb)
+            key = apply_rotary_emb_mindiesd(key, image_rotary_emb)
 
         # the output of sdp = (batch, num_heads, seq_len, head_dim)
         hidden_states = apply_fa(query, key, value, attention_mask)
@@ -465,8 +467,8 @@ class FluxAttnProcessor2_0:
            value = torch.cat([encoder_hidden_states_value_proj, value], dim=1)
 
        if image_rotary_emb is not None:
-            query = apply_rotary_emb_mindspeed(query, image_rotary_emb)
-            key = apply_rotary_emb_mindspeed(key, image_rotary_emb)
+            query = apply_rotary_emb_mindiesd(query, image_rotary_emb)
+            key = apply_rotary_emb_mindiesd(key, image_rotary_emb)
 
        hidden_states = apply_fa(query, key, value, attention_mask)
 
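Note (not part of the patch): the renamed helper wraps MindIE SD's fused rotary_position_embedding kernel, and the new cos[None, None] / sin[None, None] lines expand cos/sin to 4-D before the fused call. As a rough, un-fused sketch of what rotated-interleaved RoPE computes, under assumed shapes that are not taken from the patch (x as [batch, seq, heads, head_dim] for head_first=False, cos/sin as [seq, head_dim // 2]), one could write:

    import torch

    def rope_rotated_interleaved_reference(x, cos, sin):
        # Assumed shapes (hypothetical, for illustration only):
        #   x:        [batch, seq, heads, head_dim]
        #   cos, sin: [seq, head_dim // 2]
        # Split the last dim into interleaved pairs (x0, x1).
        x0, x1 = x[..., 0::2], x[..., 1::2]
        # Insert broadcast dims so cos/sin align with x along batch and heads.
        cos = cos[None, :, None, :].to(x.device)
        sin = sin[None, :, None, :].to(x.device)
        out = torch.empty_like(x)
        # Rotate each pair by the per-position angle.
        out[..., 0::2] = x0 * cos - x1 * sin
        out[..., 1::2] = x0 * sin + x1 * cos
        return out

The exact cos/sin layout expected by the fused kernel (hence cos[None, None] rather than the broadcasting used above) is internal to MindIE SD and may differ from this sketch.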