diff --git a/MindFlow/mindflow/cell/__init__.py b/MindFlow/mindflow/cell/__init__.py
index 59bbc9634b4fd0f472e468f5e9a69790000c8a5d..437a84319a6ff880ca3f107724677d15c959923b 100644
--- a/MindFlow/mindflow/cell/__init__.py
+++ b/MindFlow/mindflow/cell/__init__.py
@@ -16,7 +16,7 @@
 from .activation import get_activation
 from .basic_block import LinearBlock, ResBlock, InputScale, FCSequential, MultiScaleFCSequential, DropPath
 from .neural_operators import FNO1D, FNO2D, FNO3D, KNO1D, KNO2D, PDENet, PeRCNN, SNO, SNO1D, SNO2D, SNO3D
-from .attention import Attention, MultiHeadAttention, AttentionBlock
+from .attention import Attention, MultiHeadAttention, AttentionBlock, FlashAttention
 from .vit import ViT
 from .unet2d import UNet2D
 from .sno_utils import poly_data, get_poly_transform, interpolate_1d_dataset, interpolate_2d_dataset
@@ -24,8 +24,8 @@ from .diffusion import DiffusionScheduler, DiffusionTrainer, DDPMScheduler, DDIM
 from .diffusion_transformer import DiffusionTransformer, ConditionDiffusionTransformer
 
 __all__ = ["get_activation", "FNO1D", "FNO2D", "FNO3D", "KNO1D", "KNO2D", "PDENet", "UNet2D", "PeRCNN",
-           "SNO", "SNO1D", "SNO2D", "SNO3D", "Attention", "MultiHeadAttention", "AttentionBlock", "ViT", "DDPMPipeline",
-           "DDIMPipeline", "DiffusionTrainer", "DiffusionScheduler", "DDPMScheduler", "DDIMScheduler",
-           "DiffusionTransformer", "ConditionDiffusionTransformer"]
+           "SNO", "SNO1D", "SNO2D", "SNO3D", "Attention", "MultiHeadAttention", "AttentionBlock", "FlashAttention",
+           "ViT", "DDPMPipeline", "DDIMPipeline", "DiffusionTrainer", "DiffusionScheduler", "DDPMScheduler",
+           "DDIMScheduler", "DiffusionTransformer", "ConditionDiffusionTransformer"]
 __all__.extend(basic_block.__all__)
 __all__.extend(sno_utils.__all__)
diff --git a/MindFlow/mindflow/cell/attention.py b/MindFlow/mindflow/cell/attention.py
index b100dd8cbee681b100859a3f25ce1b7ff8e9b82d..a6e814c9d1dbbd6006fd05ac7a5f79821b51c9c1 100644
--- a/MindFlow/mindflow/cell/attention.py
+++ b/MindFlow/mindflow/cell/attention.py
@@ -185,6 +185,67 @@ class MultiHeadAttention(Attention):
         return output
 
 
+class FlashAttention(Attention):
+    r"""FlashAttention proposed in `FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness <https://arxiv.org/abs/2205.14135>`_.
+
+    Args:
+        in_channels (int): The input channels.
+        num_heads (int): The number of attention heads.
+        fa_dtype (mindspore.dtype): Flash attention compute dtype. Choose from ``mstype.bfloat16``, ``mstype.float16``.
+            Default: ``mstype.bfloat16``, which indicates ``mindspore.bfloat16``.
+
+    Inputs:
+        - **x** (Tensor) - Tensor with shape :math:`(batch\_size, sequence\_len, in\_channels)`.
+        - **attn_mask** (Tensor) - Tensor with shape :math:`(sequence\_len, sequence\_len)`
+          or :math:`(batch\_size, 1, sequence\_len, sequence\_len)`.
+
+    Outputs:
+        - **output** (Tensor) - Tensor with shape :math:`(batch\_size, sequence\_len, in\_channels)`.
+
+    Supported Platforms:
+        ``Ascend``
+
+    Examples:
+        >>> from mindspore import ops
+        >>> from mindflow.cell import FlashAttention
+        >>> model = FlashAttention(in_channels=512, num_heads=4)
+        >>> x = ops.rand((2, 32, 512))
+        >>> mask_shape = (2, 1, 32, 32)
+        >>> mask = ops.ones(mask_shape)
+        >>> output = model(x, mask)
+        >>> print(output.shape)
+        (2, 32, 512)
+    """
+
+    def __init__(self, in_channels, num_heads, fa_dtype=mstype.bfloat16):
+        super().__init__(in_channels, num_heads, compute_dtype=mstype.float32)
+        assert (
+            in_channels % num_heads == 0
+        ), "in_channels must be divisible by num_heads"
+        assert (
+            fa_dtype in (mstype.bfloat16, mstype.float16)
+        ), "FlashAttention only supports bfloat16 and float16"
+        self.scale = (in_channels // num_heads) ** -0.5
+        self.proj = nn.Dense(in_channels, in_channels)
+        self.fa_dtype = fa_dtype
+        self.num_heads = num_heads
+
+    # pylint: disable=W0221
+    def construct(self, x, attn_mask=None):
+        """construct"""
+        batch, node, _ = x.shape
+        query, key, value = self.get_qkv(x)
+        query, key, value = query.astype(self.fa_dtype), key.astype(self.fa_dtype), value.astype(self.fa_dtype)
+        if attn_mask is not None:
+            attn_mask = attn_mask.astype(mstype.uint8)
+        scores = ops.flash_attention_score(query, key, value, input_layout='BNSD', head_num=self.num_heads,
+                                           attn_mask=attn_mask, scalar_value=self.scale)
+        scores = ops.transpose(scores, (0, 2, 1, 3)).reshape(batch, node, -1)
+        scores = scores.astype(mstype.float32)
+        output = self.proj(scores)
+        return output
+
+
 class Mlp(nn.Cell):
     """Mlp"""
 
diff --git a/tests/st/mindflow/cell/attention/test_attention.py b/tests/st/mindflow/cell/attention/test_attention.py
index aceb878ba01e3cb86bd1c2a77aadad175395477d..1c0b672dd3fc44a6ce4f609658e22839fc7a2978 100644
--- a/tests/st/mindflow/cell/attention/test_attention.py
+++ b/tests/st/mindflow/cell/attention/test_attention.py
@@ -21,7 +21,7 @@ import numpy as np
 
 from mindspore import Tensor, ops, load_checkpoint, load_param_into_net, jit_class, context
 from mindspore import dtype as mstype
-from mindflow.cell import Attention, MultiHeadAttention, AttentionBlock, DropPath, ViT
+from mindflow.cell import Attention, MultiHeadAttention, AttentionBlock, DropPath, ViT, FlashAttention
 from mindflow.core import RelativeRMSELoss
 
 PROJECT_ROOT = os.path.abspath(os.path.join(
@@ -79,6 +79,27 @@ def test_attention_dtype(mode):
     assert v.dtype == mstype.float16
 
 
+@pytest.mark.level0
+@pytest.mark.platform_arm_ascend910b_training
+@pytest.mark.env_onecard
+@pytest.mark.parametrize('mode', [context.GRAPH_MODE, context.PYNATIVE_MODE])
+@pytest.mark.parametrize('fa_dtype', [mstype.float16, mstype.bfloat16])
+def test_flashattention(mode, fa_dtype):
+    """
+    Feature: FlashAttention
+    Description: test forward result
+    Expectation: success
+    """
+    context.set_context(mode=mode)
+    net = FlashAttention(IN_CHANNELS, NUM_HEADS, fa_dtype=fa_dtype)
+    in_shape = (BATCH_SIZE, SEQ_LEN, IN_CHANNELS)
+    x = ops.randn(in_shape)
+    mask = ops.randint(0, 2, (BATCH_SIZE, 1, SEQ_LEN, SEQ_LEN))
+    output = net(x, mask)
+    assert output.dtype == mstype.float32
+    assert output.shape == in_shape
+
+
 # pylint: disable=W0212
 @pytest.mark.level0
 @pytest.mark.platform_arm_ascend910b_training
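
For quick reference, a minimal usage sketch of the newly exported FlashAttention cell, mirroring the docstring example above. The device setup is an assumption (ops.flash_attention_score runs only on Ascend backends), and the shapes are illustrative only.

# Usage sketch for the new FlashAttention cell (mirrors the docstring example).
# Assumes an Ascend 910B device and a MindSpore build that provides
# ops.flash_attention_score.
from mindspore import context, ops
from mindspore import dtype as mstype
from mindflow.cell import FlashAttention

context.set_context(device_target="Ascend")
model = FlashAttention(in_channels=512, num_heads=4, fa_dtype=mstype.bfloat16)
x = ops.rand((2, 32, 512))                # (batch_size, sequence_len, in_channels)
mask = ops.randint(0, 2, (2, 1, 32, 32))  # 0/1 mask, cast to uint8 inside construct
output = model(x, mask)                   # float32, shape (2, 32, 512)
print(output.shape)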