diff --git a/MindFlow/mindflow/cell/__init__.py b/MindFlow/mindflow/cell/__init__.py index 80bb713393c8875d8d5d07253f5ed1c3e89fac22..59bbc9634b4fd0f472e468f5e9a69790000c8a5d 100644 --- a/MindFlow/mindflow/cell/__init__.py +++ b/MindFlow/mindflow/cell/__init__.py @@ -15,7 +15,7 @@ """init""" from .activation import get_activation from .basic_block import LinearBlock, ResBlock, InputScale, FCSequential, MultiScaleFCSequential, DropPath -from .neural_operators import FNO1D, FNO2D, FNO3D, KNO1D, KNO2D, PDENet, PeRCNN, SNO1D, SNO2D, SNO3D +from .neural_operators import FNO1D, FNO2D, FNO3D, KNO1D, KNO2D, PDENet, PeRCNN, SNO, SNO1D, SNO2D, SNO3D from .attention import Attention, MultiHeadAttention, AttentionBlock from .vit import ViT from .unet2d import UNet2D @@ -24,7 +24,7 @@ from .diffusion import DiffusionScheduler, DiffusionTrainer, DDPMScheduler, DDIM from .diffusion_transformer import DiffusionTransformer, ConditionDiffusionTransformer __all__ = ["get_activation", "FNO1D", "FNO2D", "FNO3D", "KNO1D", "KNO2D", "PDENet", "UNet2D", "PeRCNN", - "SNO1D", "SNO2D", "SNO3D", "Attention", "MultiHeadAttention", "AttentionBlock", "ViT", "DDPMPipeline", + "SNO", "SNO1D", "SNO2D", "SNO3D", "Attention", "MultiHeadAttention", "AttentionBlock", "ViT", "DDPMPipeline", "DDIMPipeline", "DiffusionTrainer", "DiffusionScheduler", "DDPMScheduler", "DDIMScheduler", "DiffusionTransformer", "ConditionDiffusionTransformer"] __all__.extend(basic_block.__all__) diff --git a/MindFlow/mindflow/cell/neural_operators/__init__.py b/MindFlow/mindflow/cell/neural_operators/__init__.py index 60d48217594f91edcd46d41ca8c2698e0e891654..4498dba834ce688ca598823d15b63a5a1be1b0af 100644 --- a/MindFlow/mindflow/cell/neural_operators/__init__.py +++ b/MindFlow/mindflow/cell/neural_operators/__init__.py @@ -18,8 +18,9 @@ from .kno1d import KNO1D from .kno2d import KNO2D from .pdenet import PDENet from .percnn import PeRCNN -from .sno import SNO1D, SNO2D, SNO3D +from .sno import SNO, SNO1D, SNO2D, SNO3D -__all__ = ["FNOBlocks", "FNO1D", "FNO2D", "FNO3D", "KNO1D", "KNO2D", "PDENet", "PeRCNN", "SNO1D", "SNO2D", "SNO3D"] +__all__ = ["FNOBlocks", "FNO1D", "FNO2D", "FNO3D", "KNO1D", "KNO2D", "PDENet", "PeRCNN", + "SNO", "SNO1D", "SNO2D", "SNO3D"] __all__.sort()