From a9c2489e6e1800589727ba8bdb11517a087622c0 Mon Sep 17 00:00:00 2001
From: Schizobulia <490124601@qq.com>
Date: Tue, 30 Sep 2025 02:41:05 +0000
Subject: [PATCH] GINO partial modules
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Signed-off-by: Schizobulia <490124601@qq.com>
---
 MindFlow/mindflow/GINO/layers/assert_close.py   |  91 +++
 .../GINO/layers/base_spectral_conv.py           |  28 +
 MindFlow/mindflow/GINO/layers/channel_mlp.py    | 117 ++++
 MindFlow/mindflow/GINO/layers/embeddings.py     | 374 ++++++++++++
 MindFlow/mindflow/GINO/layers/fno_block.py      | 429 ++++++++++++++
 .../GINO/layers/gno_weighting_functions.py      |  58 ++
 .../GINO/layers/integral_transform.py           | 243 ++++++++
 .../mindflow/GINO/layers/neighbor_search.py     | 116 ++++
 MindFlow/mindflow/GINO/layers/segment_csr.py    | 100 ++++
 .../GINO/layers/tests/test_assert_close.py      |  74 +++
 .../GINO/layers/tests/test_gno_block.py         | 116 ++++
 .../GINO/layers/tests/test_grid_embeddings.py   |  45 ++
 .../GINO/layers/tests/test_neighbor_search.py   |  49 ++
 .../GINO/layers/tests/test_segment_csr.py       |  47 ++
 .../GINO/layers/tests/test_sin_embeddings.py    |  96 ++++
 MindFlow/mindflow/GINO/models/gino.py           | 533 ++++++++++++++++++
 16 files changed, 2516 insertions(+)
 create mode 100644 MindFlow/mindflow/GINO/layers/assert_close.py
 create mode 100644 MindFlow/mindflow/GINO/layers/base_spectral_conv.py
 create mode 100644 MindFlow/mindflow/GINO/layers/channel_mlp.py
 create mode 100644 MindFlow/mindflow/GINO/layers/embeddings.py
 create mode 100644 MindFlow/mindflow/GINO/layers/fno_block.py
 create mode 100644 MindFlow/mindflow/GINO/layers/gno_weighting_functions.py
 create mode 100644 MindFlow/mindflow/GINO/layers/integral_transform.py
 create mode 100644 MindFlow/mindflow/GINO/layers/neighbor_search.py
 create mode 100644 MindFlow/mindflow/GINO/layers/segment_csr.py
 create mode 100644 MindFlow/mindflow/GINO/layers/tests/test_assert_close.py
 create mode 100644 MindFlow/mindflow/GINO/layers/tests/test_gno_block.py
 create mode 100644 MindFlow/mindflow/GINO/layers/tests/test_grid_embeddings.py
 create mode 100644 MindFlow/mindflow/GINO/layers/tests/test_neighbor_search.py
 create mode 100644 MindFlow/mindflow/GINO/layers/tests/test_segment_csr.py
 create mode 100644 MindFlow/mindflow/GINO/layers/tests/test_sin_embeddings.py
 create mode 100644 MindFlow/mindflow/GINO/models/gino.py

diff --git a/MindFlow/mindflow/GINO/layers/assert_close.py b/MindFlow/mindflow/GINO/layers/assert_close.py
new file mode 100644
index 000000000..fd95279b9
--- /dev/null
+++ b/MindFlow/mindflow/GINO/layers/assert_close.py
@@ -0,0 +1,91 @@
+import mindspore
+from mindspore import mint, ops
+import numpy as np
+
+def assert_close(
+    actual,
+    expected,
+    rtol=1e-05,
+    atol=1e-08,
+    equal_nan=False,
+    check_dtype=True,
+    check_stride=False,
+    msg=None,
+):
+    """
+    Custom assertion that checks whether two tensors are equal within the
+    given tolerances.
+
+    Args:
+        actual: actual tensor
+        expected: expected tensor
+        rtol: relative tolerance
+        atol: absolute tolerance
+        equal_nan: if True, NaN values at matching positions are treated as equal
+        check_dtype: whether to check dtype consistency
+        check_stride: whether to check stride consistency
+        msg: custom error message raised on assertion failure
+    """
+    # make sure both inputs are tensors
+    if not isinstance(actual, mindspore.Tensor):
+        actual = mindspore.tensor(actual)
+    if not isinstance(expected, mindspore.Tensor):
+        expected = mindspore.tensor(expected)
+
+    # check dtype consistency
+    if check_dtype and actual.dtype != expected.dtype:
+        raise AssertionError(
+            f"Tensor dtypes do not match: actual ({actual.dtype}) vs expected ({expected.dtype})"
+        )
+
+    # check shape consistency
+    if actual.shape != expected.shape:
+        raise AssertionError(
+            f"Tensor shapes do not match: actual ({actual.shape}) vs expected ({expected.shape})"
+        )
+
+    # check stride consistency (if requested)
+    if check_stride and actual.stride() != expected.stride():
+        raise AssertionError(
+            f"Tensor strides do not match: actual ({actual.stride()}) vs expected ({expected.stride()})"
+        )
+
+    # maximum absolute error
+    max_abs_error = mint.max(mint.abs(actual - expected)).item()
+
+    # tolerance threshold combining the absolute and relative parts
+    tolerance = atol + rtol * mint.max(mint.abs(expected)).item()
+
+    # check whether any element exceeds the tolerance
+    if max_abs_error > tolerance:
+        # locate the first element that exceeds the tolerance
+        diff_mask = mint.abs(actual - expected) > tolerance
+        first_diff_idx = tuple(int(idx[0]) for idx in mint.nonzero(diff_mask, as_tuple=True))
+
+        # values at the first mismatching position
+        actual_val = actual[first_diff_idx].item()
+        expected_val = expected[first_diff_idx].item()
+
+        # build the error message
+        default_msg = (
+            f"Tensors are not close. Max absolute error={max_abs_error:.6e}, "
+            f"exceeds tolerance={tolerance:.6e}, "
+            f"first mismatch at index={first_diff_idx}, "
+            f"actual={actual_val:.6e}, expected={expected_val:.6e}"
+        )
+
+        raise AssertionError(msg or default_msg)
+
+    # NaN handling: NaN presence must agree in both tensors; with
+    # equal_nan=True, NaNs must additionally occupy the same positions
+    actual_nan_mask = ops.isnan(actual)
+    expected_nan_mask = ops.isnan(expected)
+    actual_has_nan = actual_nan_mask.any().item()
+    expected_has_nan = expected_nan_mask.any().item()
+
+    if actual_has_nan != expected_has_nan:
+        raise AssertionError(
+            f"NaN presence mismatch: actual contains NaN: {actual_has_nan}, "
+            f"expected contains NaN: {expected_has_nan}"
+        )
+
+    if equal_nan and actual_has_nan and expected_has_nan:
+        # NaNs compare equal only if they occur at the same positions
+        if not (actual_nan_mask == expected_nan_mask).all():
+            raise AssertionError("NaN positions do not match")
\ No newline at end of file
diff --git a/MindFlow/mindflow/GINO/layers/base_spectral_conv.py b/MindFlow/mindflow/GINO/layers/base_spectral_conv.py
new file mode 100644
index 000000000..4cc601277
--- /dev/null
+++ b/MindFlow/mindflow/GINO/layers/base_spectral_conv.py
@@ -0,0 +1,28 @@
+# from torch import nn
+from mindspore import nn
+
+# mindspore.mint.nn does not expose Cell; cells derive from mindspore.nn.Cell
+class BaseSpectralConv(nn.Cell):
+# class BaseSpectralConv(nn.Module):
+    def __init__(self, device=None, dtype=None):
+        """Base Class for Spectral Convolutions
+
+        Use it when you want to build your own FNO-type Neural Operators
+        """
+        super().__init__()
+
+        self.dtype = dtype
+        self.device = device
+
+    def transform(self, x):
+        """Transforms an input x for a skip connection, by default just an identity map
+
+        If your function transforms the input then you should also implement this transform method
+        so the skip connection can also work.
+
+        Typical use cases are:
+
+        * You upsample or downsample the input in the spectral conv: the skip connection has to be similarly scaled.
+          This allows you to deal with it however you want (e.g. avoid aliasing)
+        * You perform a change of basis in your spectral conv; again, this needs to be applied to the skip connection too.
+        """
+        return x
diff --git a/MindFlow/mindflow/GINO/layers/channel_mlp.py b/MindFlow/mindflow/GINO/layers/channel_mlp.py
new file mode 100644
index 000000000..39e6c66a7
--- /dev/null
+++ b/MindFlow/mindflow/GINO/layers/channel_mlp.py
@@ -0,0 +1,117 @@
+import mindspore
+from mindspore.mint import nn
+from mindspore.nn import Cell, CellList
+import mindspore.mint.nn.functional as F
+
+
+class ChannelMLP(Cell):
+    """ChannelMLP applies an arbitrary number of layers of
+    1d convolution and nonlinearity to the channels of input
+    and is invariant to spatial resolution.
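+
+    A rough usage sketch (shapes are illustrative):
+
+    .. code-block:: python
+
+        mlp = ChannelMLP(in_channels=3, hidden_channels=16, out_channels=8)
+        x = mindspore.mint.randn(4, 3, 32, 32)   # (batch, channels, d_1, d_2)
+        y = mlp(x)                               # (4, 8, 32, 32)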
+
+    Parameters
+    ----------
+    in_channels : int
+    out_channels : int, default is None
+        if None, same as in_channels
+    hidden_channels : int, default is None
+        if None, same as in_channels
+    n_layers : int, default is 2
+        number of linear layers in the MLP
+    non_linearity : default is F.gelu
+    dropout : float, default is 0
+        if > 0, dropout probability
+    """
+
+    def __init__(
+        self,
+        in_channels,
+        out_channels=None,
+        hidden_channels=None,
+        n_layers=2,
+        n_dim=2,
+        non_linearity=F.gelu,
+        dropout=0.0,
+        **kwargs,
+    ):
+        super().__init__()
+        self.n_layers = n_layers
+        self.in_channels = in_channels
+        self.out_channels = in_channels if out_channels is None else out_channels
+        self.hidden_channels = (
+            in_channels if hidden_channels is None else hidden_channels
+        )
+        self.non_linearity = non_linearity
+        self.dropout = (
+            CellList([nn.Dropout(dropout) for _ in range(n_layers)])
+            if dropout > 0.0
+            else None
+        )
+
+        # we use nn.Conv1d for everything and roll data along the 1st data dim
+        self.fcs = CellList()
+        for i in range(n_layers):
+            if i == 0 and i == (n_layers - 1):
+                self.fcs.append(mindspore.nn.Conv1d(self.in_channels, self.out_channels, 1, has_bias=True, pad_mode='valid'))
+            elif i == 0:
+                self.fcs.append(mindspore.nn.Conv1d(self.in_channels, self.hidden_channels, 1, has_bias=True, pad_mode='valid'))
+            elif i == (n_layers - 1):
+                self.fcs.append(mindspore.nn.Conv1d(self.hidden_channels, self.out_channels, 1, has_bias=True, pad_mode='valid'))
+            else:
+                self.fcs.append(mindspore.nn.Conv1d(self.hidden_channels, self.hidden_channels, 1, has_bias=True, pad_mode='valid'))
+
+    def construct(self, x):
+        reshaped = False
+        size = list(x.shape)
+        if x.ndim > 3:
+            # batch, channels, x1, x2... extra dims
+            # .reshape() is preferable but .view()
+            # cannot be called on non-contiguous tensors
+            x = x.reshape((*size[:2], -1))
+            reshaped = True
+
+        for i, fc in enumerate(self.fcs):
+            x = fc(x)
+            if i < self.n_layers - 1:
+                x = self.non_linearity(x)
+            if self.dropout is not None:
+                x = self.dropout[i](x)
+
+        # if x was an N-d tensor reshaped into 1d, undo the reshaping
+        # same logic as above: .reshape() handles contiguous tensors as well
+        if reshaped:
+            x = x.reshape((size[0], self.out_channels, *size[2:]))
+
+        return x
+
+
+# Reimplementation of the ChannelMLP class using Linear instead of Conv
+class LinearChannelMLP(Cell):
+    def __init__(self, layers, non_linearity=F.gelu, dropout=0.0):
+        super().__init__()
+
+        self.n_layers = len(layers) - 1
+
+        assert self.n_layers >= 1, "Error: trying to instantiate \
+            a LinearChannelMLP with fewer than two layer sizes."
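+        # e.g. layers=[64, 128, 1] describes two Linear layers (64->128 and
+        # 128->1), so at least two layer sizes are required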
+ + self.fcs = CellList() + self.non_linearity = non_linearity + self.dropout = ( + CellList([nn.Dropout(dropout) for _ in range(self.n_layers)]) + if dropout > 0.0 + else None + ) + + for j in range(self.n_layers): + self.fcs.append(nn.Linear(layers[j], layers[j + 1])) + + def construct(self, x): + for i, fc in enumerate(self.fcs): + x = fc(x) + if i < self.n_layers - 1: + x = self.non_linearity(x) + if self.dropout is not None: + x = self.dropout[i](x) + + return x diff --git a/MindFlow/mindflow/GINO/layers/embeddings.py b/MindFlow/mindflow/GINO/layers/embeddings.py new file mode 100644 index 000000000..3f61efb85 --- /dev/null +++ b/MindFlow/mindflow/GINO/layers/embeddings.py @@ -0,0 +1,374 @@ +from abc import ABC, abstractmethod +from typing import List + +import math +from mindspore import mint +from mindspore.nn import Cell + + +class Embedding(Cell, ABC): + def __init__(self): + super().__init__() + + @property + @abstractmethod + def out_channels(self): + pass + +class GridEmbedding2D(Embedding): + """ + ``GridEmbedding2D`` applies a simple positional + embedding as a regular 2D grid. Expects inputs of shape + ``(batch, channels, d_1, d_2)`` + + Parameters + ---------- + in_channels : ``int`` + number of channels in input. Fixed for output channel interface + grid_boundaries : ``list``, optional + coordinate boundaries of input grid, by default [[0, 1], [0, 1]] + """ + def __init__(self, in_channels: int, grid_boundaries=[[0, 1], [0, 1]]): + super().__init__() + self.in_channels = in_channels + self.grid_boundaries = grid_boundaries + self._grid = None + self._res = None + + @property + def out_channels(self): + return self.in_channels + 2 + + def grid(self, spatial_dims, dtype): + """grid generates 2D grid needed for pos encoding + and caches the grid associated with MRU resolution + + Parameters + ---------- + spatial_dims : mindspore.shape + sizes of spatial resolution + dtype : str + dtype to encode data + + Returns + ------- + mindspore.tensor + output grids to concatenate + """ + # handle case of multiple train resolutions + if self._grid is None or self._res != spatial_dims: + grid_x, grid_y = regular_grid_2d(spatial_dims, + grid_boundaries=self.grid_boundaries) + grid_x = grid_x.to(dtype).unsqueeze(0).unsqueeze(0) + grid_y = grid_y.to(dtype).unsqueeze(0).unsqueeze(0) + self._grid = grid_x, grid_y + self._res = spatial_dims + + return self._grid + + def construct(self, data, batched=True): + if not batched: + if data.ndim == 3: + data = data.unsqueeze(0) + batch_size = data.shape[0] + x, y = self.grid(data.shape[-2:], data.dtype) + out = mint.cat((data, x.expand((batch_size, -1, -1, -1)), + y.expand((batch_size, -1, -1, -1))), + dim=1) + # in the unbatched case, the dataloader will stack N + # examples with no batch dim to create one + if not batched and batch_size == 1: + return out.squeeze(0) + else: + return out + +class GridEmbeddingND(Cell): + """ + GridEmbeddingND applies a simple positional + embedding as a regular ND grid. Expects inputs of shape + ``(batch, channels, d_1, ..., d_n)``. 
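+
+    A rough usage sketch (shapes are illustrative):
+
+    .. code-block:: python
+
+        embed = GridEmbeddingND(in_channels=2, dim=3, grid_boundaries=[[0, 1]] * 3)
+        x = mindspore.mint.randn(4, 2, 16, 16, 16)
+        y = embed(x)   # (4, 2 + 3, 16, 16, 16): one extra coordinate channel per dim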
+
+    Parameters
+    ----------
+    in_channels : int
+        number of channels in input
+    dim : int
+        dimensions of positional encoding to apply
+    grid_boundaries : list, optional
+        coordinate boundaries of input grid along each dim, by default [[0, 1], [0, 1]]
+    """
+    def __init__(self, in_channels: int, dim: int=2, grid_boundaries=[[0, 1], [0, 1]]):
+        super().__init__()
+        self.in_channels = in_channels
+        self.dim = dim
+        assert self.dim == len(grid_boundaries), f"Error: expected grid_boundaries to be\
+            an iterable of length {self.dim}, received {grid_boundaries}"
+        self.grid_boundaries = grid_boundaries
+        self._grid = None
+        self._res = None
+
+    @property
+    def out_channels(self):
+        return self.in_channels + self.dim
+
+    def grid(self, spatial_dims, dtype):
+        """grid generates the ND grid needed for pos encoding and caches
+        the grid associated with the most recently used (MRU) resolution
+
+        Parameters
+        ----------
+        spatial_dims : tuple of int
+            sizes of spatial resolution
+        dtype : str
+            dtype to encode data
+
+        Returns
+        -------
+        mindspore.tensor
+            output grids to concatenate
+        """
+        # handle case of multiple train resolutions
+        if self._grid is None or self._res != spatial_dims:
+            grids_by_dim = regular_grid_nd(spatial_dims,
+                                           grid_boundaries=self.grid_boundaries)
+            # add batch, channel dims
+            grids_by_dim = [x.to(dtype).unsqueeze(0).unsqueeze(0) for x in grids_by_dim]
+            self._grid = grids_by_dim
+            self._res = spatial_dims
+
+        return self._grid
+
+    def construct(self, data, batched=True):
+        """
+        Params
+        --------
+        data: mindspore.Tensor
+            assumes shape ``(batch (optional), channels, x_1, x_2, ...x_n)``
+        batched: bool
+            whether data has a batch dim
+        """
+        # add batch dim if it doesn't exist
+        if not batched:
+            if data.ndim == self.dim + 1:
+                data = data.unsqueeze(0)
+        batch_size = data.shape[0]
+        grids = self.grid(spatial_dims=data.shape[2:],
+                          dtype=data.dtype)
+        grids = [x.repeat(batch_size, *[1] * (self.dim+1)) for x in grids]
+        out = mint.cat((data, *grids),
+                       dim=1)
+        return out
+
+class SinusoidalEmbedding(Embedding):
+    """
+    SinusoidalEmbedding provides a unified sinusoidal positional embedding
+    in the styles of Transformers [1]_ and Neural Radiance Fields (NeRFs) [2]_.
+
+    Expects inputs of shape ``(batch, n_in, in_channels)`` or ``(n_in, in_channels)``
+
+    Parameters
+    ----------
+    in_channels : ``int``
+        Number of input channels to embed
+    num_frequencies : ``int``, optional
+        Number of frequencies in positional embedding.
+        By default, set to the number of input channels
+    embedding_type : ``{'transformer', 'nerf'}``
+        Type of embedding to apply. Each channel value p is embedded via
+        a function g with 2L channels such that g(p) is a 2L-dim vector,
+        where L = num_frequencies. For 0 <= k < L:
+
+        * ``'transformer'`` for transformer-style encoding.
+
+            g(p)_{2k} = sin(p / max_positions^{k / (2L)})
+
+            g(p)_{2k+1} = cos(p / max_positions^{k / (2L)})
+
+        * ``'nerf'`` : NeRF-style encoding.
+
+            g(p)_{2k} = sin(2^k * Pi * p)
+
+            g(p)_{2k+1} = cos(2^k * Pi * p)
+
+    max_positions : ``int``, optional
+        Maximum number of positions for the encoding, default 10000
+        Only used if ``embedding_type == 'transformer'``.
+
+    References
+    -----------
+    .. [1] : Vaswani, A. et al (2017)
+        "Attention Is All You Need".
+        NeurIPS 2017, https://proceedings.neurips.cc/paper_files/paper/2017/file/3f5ee243547dee91fbd053c1c4a845aa-Paper.pdf.
+
+    .. [2] : Mildenhall, B. et al (2020)
+        "NeRF: Representing Scenes as Neural Radiance Fields for View Synthesis".
+        ArXiv, https://arxiv.org/pdf/2003.08934.
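+
+    Examples
+    --------
+    A rough usage sketch (shapes are illustrative):
+
+    .. code-block:: python
+
+        embed = SinusoidalEmbedding(in_channels=2, num_frequencies=4,
+                                    embedding_type='nerf')
+        x = mindspore.mint.randn(8, 100, 2)   # (batch, n_in, in_channels)
+        out = embed(x)   # (8, 100, 16), since out_channels = 2 * 4 * 2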
+ """ + def __init__(self, + in_channels: int, + num_frequencies: int=None, + embedding_type: str='transformer', + max_positions: int=10000): + super().__init__() + self.in_channels = in_channels + self.num_frequencies = num_frequencies + + # verify embedding type + allowed_embeddings = ['nerf', 'transformer'] + assert embedding_type in allowed_embeddings, \ + f"Error: embedding_type expected one of {allowed_embeddings}, received {embedding_type}" + self.embedding_type = embedding_type + if self.embedding_type == "transformer": + assert max_positions is not None, "Error: max_positions must have an int value for \ + transformer embedding." + self.max_positions = max_positions + + + @property + def out_channels(self): + """ + required property for linking/composing model layers + """ + return 2 * self.num_frequencies * self.in_channels + + def construct(self, x): + """ + Parameters + ----------- + x: ``mindspore.Tensor`` + shape ``(n_in, self.in_channels)`` or ``(batch, n_in, self.in_channels)`` + """ + assert x.ndim in [2,3], f"Error: expected inputs of shape (batch, n_in, {self.in_channels})\ + or (n_in, channels), got inputs with ndim={x.ndim}, shape={x.shape}" + if x.ndim == 2: + batched = False + x = x.unsqueeze(0) + else: + batched = True + batch_size, n_in, _ = x.shape + + if self.embedding_type == 'nerf': + freqs = 2 ** mint.arange(0, self.num_frequencies) * math.pi + + elif self.embedding_type == 'transformer': + freqs = mint.arange(0, self.num_frequencies) / (self.num_frequencies * 2) + freqs = (1 / self.max_positions) ** freqs + + # outer product of wavenumbers and position coordinates + # shape b, n_in * channels, len(freqs) + freqs = mint.einsum('bij, k -> bijk', x, freqs) + + # shape len(x), 2, len(freqs) + freqs = mint.stack((freqs.sin(),freqs.cos()), dim=-1) + + # transpose the inner per-entry matrix and ravel to interleave sin and cos + freqs = freqs.view(batch_size, n_in, -1) + + if not batched: + freqs = freqs.squeeze(0) + return freqs + +class RotaryEmbedding2D(Cell): + def __init__(self, dim, min_freq=1/64, scale=1.): + """ + Applying rotary positional embedding (https://arxiv.org/abs/2104.09864) to the input feature tensor. + The crux is the dot product of two rotation matrices R(theta1) and R(theta2) is equal to R(theta2 - theta1). + """ + super().__init__() + inv_freq = 1. / (10000 ** (mint.arange(0, dim, 2).float() / dim)) + self.min_freq = min_freq + self.scale = scale + self.register_buffer('inv_freq', inv_freq, persistent=False) + self.out_channels = 2 + + def construct(self, coordinates): + """coordinates is tensor of [batch_size, num_points]""" + coordinates = coordinates * (self.scale / self.min_freq) + freqs = mint.einsum('... i , j -> ... i j', coordinates, self.inv_freq) # [b, n, d//2] + return mint.cat((freqs, freqs), dim=-1) # [b, n, d] + + @staticmethod + def apply_1d_rotary_pos_emb(t, freqs): + return apply_rotary_pos_emb(t, freqs) + + @staticmethod + def apply_2d_rotary_pos_emb(t, freqs_x, freqs_y): + """Split the last dimension of features into two equal halves + and apply 1d rotary positional embedding to each half.""" + d = t.shape[-1] + t_x, t_y = t[..., :d//2], t[..., d//2:] + + return mint.cat((apply_rotary_pos_emb(t_x, freqs_x), + apply_rotary_pos_emb(t_y, freqs_y)), dim=-1) + +# Utility functions for GridEmbedding +def regular_grid_2d(spatial_dims, grid_boundaries=[[0, 1], [0, 1]]): + """ + Creates a 2 x height x width stack of positional encodings A, where + A[:,i,j] = [[x,y]] at coordinate (i,j) on a (height, width) grid. 
+ """ + height, width = spatial_dims + + xt = mint.linspace(grid_boundaries[0][0], grid_boundaries[0][1], + height + 1)[:-1] + yt = mint.linspace(grid_boundaries[1][0], grid_boundaries[1][1], + width + 1)[:-1] + + grid_x, grid_y = mint.meshgrid(xt, yt, indexing='ij') + + grid_x = grid_x.repeat(1, 1) + grid_y = grid_y.repeat(1, 1) + + return grid_x, grid_y + +def regular_grid_nd(resolutions: List[int], grid_boundaries: List[List[int]]=[[0,1]] * 2): + """regular_grid_nd generates a tensor of coordinate points that + describe a bounded regular grid. + + Creates a dim x res_d1 x ... x res_dn stack of positional encodings A, where + A[:,c1,c2,...] = [[d1,d2,...dn]] at coordinate (c1,c2,...cn) on a (res_d1, ...res_dn) grid. + + Parameters + ---------- + resolutions : List[int] + resolution of the output grid along each dimension + grid_boundaries : List[List[int]], optional + List of pairs [start, end] of the boundaries of the + regular grid. Must correspond 1-to-1 with resolutions default [[0,1], [0,1]] + + Returns + ------- + grid: tuple(Tensor) + list of tensors describing positional encoding + """ + assert len(resolutions) == len(grid_boundaries), "Error: inputs must have same number of dimensions" + dim = len(resolutions) + + meshgrid_inputs = list() + for res, (start,stop) in zip(resolutions, grid_boundaries): + meshgrid_inputs.append(mint.linspace(start, stop, res + 1)[:-1]) + grid = mint.meshgrid(*meshgrid_inputs, indexing='ij') + grid = tuple([x.repeat([1]*dim) for x in grid]) + return grid + + +# Utility fucntions for Rotary embedding +# modified from https://github.com/lucidrains/x-transformers/blob/main/x_transformers/x_transformers.py +def rotate_half(x): + """ + Split x's channels into two equal halves. + """ + # split the last dimension of x into two equal halves + x = x.reshape(*x.shape[:-1], 2, -1) + x1, x2 = x.unbind(dim=-2) + return mint.cat((-x2, x1), dim=-1) + + +def apply_rotary_pos_emb(t, freqs): + """ + Apply rotation matrix computed based on freqs to rotate t. + t: tensor of shape [batch_size, num_points, dim] + freqs: tensor of shape [batch_size, num_points, 1] + + Formula: see equation (34) in https://arxiv.org/pdf/2104.09864.pdf + """ + return (t * freqs.cos()) + (rotate_half(t) * freqs.sin()) diff --git a/MindFlow/mindflow/GINO/layers/fno_block.py b/MindFlow/mindflow/GINO/layers/fno_block.py new file mode 100644 index 000000000..005aa643d --- /dev/null +++ b/MindFlow/mindflow/GINO/layers/fno_block.py @@ -0,0 +1,429 @@ +from typing import List, Optional, Union + +# import torch +# from torch import nn +# import torch.nn.functional as F + +import mindspore +from mindspore import nn +import mindspore.mint.nn.functional as F + +from .channel_mlp import ChannelMLP +from .complex import CGELU, apply_complex, ctanh, ComplexValued +from .normalization_layers import AdaIN, InstanceNorm, BatchNorm +from .skip_connections import skip_connection +from .spectral_convolution import SpectralConv +from ..utils import validate_scaling_factor + + +Number = Union[int, float] + + +# class FNOBlocks(nn.Module): +class FNOBlocks(nn.Cell): + """FNOBlocks implements a sequence of Fourier layers, the operations of which + are first described in [1]_. The exact implementation details of the Fourier + layer architecture are discussed in [2]_. + + Parameters + ---------- + in_channels : int + input channels to Fourier layers + out_channels : int + output channels after Fourier layers + n_modes : int, List[int] + number of modes to keep along each dimension + in frequency space. 
Can either be specified as + an int (for all dimensions) or an iterable with one + number per dimension + resolution_scaling_factor : Optional[Union[Number, List[Number]]], optional + factor by which to scale outputs for super-resolution, by default None + n_layers : int, optional + number of Fourier layers to apply in sequence, by default 1 + max_n_modes : int, List[int], optional + maximum number of modes to keep along each dimension, by default None + fno_block_precision : str, optional + floating point precision to use for computations, by default "full" + use_channel_mlp : bool, optional + Whether to use an MLP layer after each FNO block, by default True + channel_mlp_dropout : int, optional + dropout parameter for self.channel_mlp, by default 0 + channel_mlp_expansion : float, optional + expansion parameter for self.channel_mlp, by default 0.5 + non_linearity : torch.nn.F module, optional + nonlinear activation function to use between layers, by default F.gelu + stabilizer : Literal["tanh"], optional + stabilizing module to use between certain layers, by default None + if "tanh", use tanh + norm : Literal["ada_in", "group_norm", "instance_norm", "batch_norm"], optional + Normalization layer to use, by default None + ada_in_features : int, optional + number of features for adaptive instance norm above, by default None + preactivation : bool, optional + whether to call forward pass with pre-activation, by default False + if True, call nonlinear activation and norm before Fourier convolution + if False, call activation and norms after Fourier convolutions + fno_skip : str, optional + module to use for FNO skip connections, by default "linear" + see layers.skip_connections for more details + channel_mlp_skip : str, optional + module to use for ChannelMLP skip connections, by default "soft-gating" + see layers.skip_connections for more details + + Other Parameters + ------------------- + complex_data : bool, optional + whether the FNO's data takes on complex values in space, by default False + separable : bool, optional + separable parameter for SpectralConv, by default False + factorization : str, optional + factorization parameter for SpectralConv, by default None + rank : float, optional + rank parameter for SpectralConv, by default 1.0 + conv_module : BaseConv, optional + module to use for convolutions in FNO block, by default SpectralConv + joint_factorization : bool, optional + whether to factorize all spectralConv weights as one tensor, by default False + fixed_rank_modes : bool, optional + fixed_rank_modes parameter for SpectralConv, by default False + implementation : str, optional + implementation parameter for SpectralConv, by default "factorized" + decomposition_kwargs : _type_, optional + kwargs for tensor decomposition in SpectralConv, by default dict() + + References + ----------- + .. [1] Li, Z. et al. "Fourier Neural Operator for Parametric Partial Differential + Equations" (2021). ICLR 2021, https://arxiv.org/pdf/2010.08895. + .. [2] Kossaifi, J., Kovachki, N., Azizzadenesheli, K., Anandkumar, A. "Multi-Grid + Tensorized Fourier Neural Operator for High-Resolution PDEs" (2024). + TMLR 2024, https://openreview.net/pdf?id=AWiDlO63bH. 
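+
+    Examples
+    --------
+    A rough usage sketch (parameter values are illustrative):
+
+    .. code-block:: python
+
+        blocks = FNOBlocks(in_channels=32, out_channels=32,
+                           n_modes=(16, 16), n_layers=4)
+        x = mindspore.mint.randn(2, 32, 64, 64)   # (batch, channels, d_1, d_2)
+        for i in range(blocks.n_layers):
+            x = blocks(x, index=i)   # apply the i-th Fourier layer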
+ """ + def __init__( + self, + in_channels, + out_channels, + n_modes, + resolution_scaling_factor=None, + n_layers=1, + max_n_modes=None, + fno_block_precision="full", + use_channel_mlp=True, + channel_mlp_dropout=0, + channel_mlp_expansion=0.5, + non_linearity=F.gelu, + stabilizer=None, + norm=None, + ada_in_features=None, + preactivation=False, + fno_skip="linear", + channel_mlp_skip="soft-gating", + complex_data=False, + separable=False, + factorization=None, + rank=1.0, + conv_module=SpectralConv, + fixed_rank_modes=False, #undoc + implementation="factorized", #undoc + decomposition_kwargs=dict(), + **kwargs, + ): + super().__init__() + if isinstance(n_modes, int): + n_modes = [n_modes] + self._n_modes = n_modes + self.n_dim = len(n_modes) + + self.resolution_scaling_factor: Union[ + None, List[List[float]] + ] = validate_scaling_factor(resolution_scaling_factor, self.n_dim, n_layers) + + self.max_n_modes = max_n_modes + self.fno_block_precision = fno_block_precision + self.in_channels = in_channels + self.out_channels = out_channels + self.n_layers = n_layers + self.stabilizer = stabilizer + self.rank = rank + self.factorization = factorization + self.fixed_rank_modes = fixed_rank_modes + self.decomposition_kwargs = decomposition_kwargs + self.fno_skip = fno_skip + self.channel_mlp_skip = channel_mlp_skip + self.complex_data = complex_data + + self.use_channel_mlp = use_channel_mlp + self.channel_mlp_expansion = channel_mlp_expansion + self.channel_mlp_dropout = channel_mlp_dropout + self.implementation = implementation + self.separable = separable + self.preactivation = preactivation + self.ada_in_features = ada_in_features + + # apply real nonlin if data is real, otherwise CGELU + if self.complex_data: + self.non_linearity = CGELU + else: + self.non_linearity = non_linearity + + # self.convs = nn.ModuleList([ + self.convs = nn.CellList([ + conv_module( + self.in_channels, + self.out_channels, + self.n_modes, + resolution_scaling_factor=None if resolution_scaling_factor is None else self.resolution_scaling_factor[i], + max_n_modes=max_n_modes, + rank=rank, + fixed_rank_modes=fixed_rank_modes, + implementation=implementation, + separable=separable, + factorization=factorization, + fno_block_precision=fno_block_precision, + decomposition_kwargs=decomposition_kwargs, + complex_data=complex_data + ) + for i in range(n_layers)]) + + # self.fno_skips = nn.ModuleList( + self.fno_skips = nn.CellList( + [ + skip_connection( + self.in_channels, + self.out_channels, + skip_type=fno_skip, + n_dim=self.n_dim, + ) + for _ in range(n_layers) + ] + ) + if self.complex_data: + # self.fno_skips = nn.ModuleList( + self.fno_skips = nn.CellList( + [ComplexValued(x) for x in self.fno_skips] + ) + + if self.use_channel_mlp: + # self.channel_mlp = nn.ModuleList( + self.channel_mlp = nn.CellList( + [ + ChannelMLP( + in_channels=self.out_channels, + hidden_channels=round(self.out_channels * channel_mlp_expansion), + dropout=channel_mlp_dropout, + n_dim=self.n_dim, + ) + for _ in range(n_layers) + ] + ) + if self.complex_data: + # self.channel_mlp = nn.ModuleList( + self.channel_mlp = nn.CellList( + [ComplexValued(x) for x in self.channel_mlp] + ) + # self.channel_mlp_skips = nn.ModuleList( + self.channel_mlp_skips = nn.CellList( + [ + skip_connection( + self.in_channels, + self.out_channels, + skip_type=channel_mlp_skip, + n_dim=self.n_dim, + ) + for _ in range(n_layers) + ] + ) + if self.complex_data: + # self.channel_mlp_skips = nn.ModuleList( + self.channel_mlp_skips = nn.CellList( + [ComplexValued(x) 
for x in self.channel_mlp_skips] + ) + + # Each block will have 2 norms if we also use a ChannelMLP + self.n_norms = 2 + if norm is None: + self.norm = None + elif norm == "instance_norm": + # self.norm = nn.ModuleList( + self.norm = nn.CellList( + [ + InstanceNorm() + for _ in range(n_layers * self.n_norms) + ] + ) + elif norm == "group_norm": + # self.norm = nn.ModuleList( + self.norm = nn.CellList( + [ + nn.GroupNorm(num_groups=1, num_channels=self.out_channels) + for _ in range(n_layers * self.n_norms) + ] + ) + + elif norm == "batch_norm": + # self.norm = nn.ModuleList( + self.norm = nn.CellList( + [ + BatchNorm(n_dim=self.n_dim, num_features=self.out_channels) + for _ in range(n_layers * self.n_norms) + ] + ) + + elif norm == "ada_in": + # self.norm = nn.ModuleList( + self.norm = nn.CellList( + [ + AdaIN(ada_in_features, out_channels) + for _ in range(n_layers * self.n_norms) + ] + ) + else: + raise ValueError( + f"Got norm={norm} but expected None or one of " + "[instance_norm, group_norm, batch_norm, ada_in]" + ) + + def set_ada_in_embeddings(self, *embeddings): + """Sets the embeddings of each Ada-IN norm layers + + Parameters + ---------- + embeddings : tensor or list of tensor + if a single embedding is given, it will be used for each norm layer + otherwise, each embedding will be used for the corresponding norm layer + """ + if self.norm is not None: + if len(embeddings) == 1: + for norm in self.norm: + norm.set_embedding(embeddings[0]) + else: + for norm, embedding in zip(self.norm, embeddings): + norm.set_embedding(embedding) + + # def forward(self, x, index=0, output_shape=None): + def construct(self, x, index=0, output_shape=None): + if self.preactivation: + return self.forward_with_preactivation(x, index, output_shape) + else: + return self.forward_with_postactivation(x, index, output_shape) + + def forward_with_postactivation(self, x, index=0, output_shape=None): + x_skip_fno = self.fno_skips[index](x) + x_skip_fno = self.convs[index].transform(x_skip_fno, output_shape=output_shape) + + if self.use_channel_mlp: + x_skip_channel_mlp = self.channel_mlp_skips[index](x) + x_skip_channel_mlp = self.convs[index].transform(x_skip_channel_mlp, output_shape=output_shape) + + if self.stabilizer == "tanh": + if self.complex_data: + x = ctanh(x) + else: + # x = torch.tanh(x) + x = mindspore.mint.tanh(x) + + x_fno = self.convs[index](x, output_shape=output_shape) + #self.convs(x, index, output_shape=output_shape) + + if self.norm is not None: + x_fno = self.norm[self.n_norms * index](x_fno) + + x = x_fno + x_skip_fno + + if (index < (self.n_layers - 1)): + x = self.non_linearity(x) + + if self.use_channel_mlp: + x = self.channel_mlp[index](x) + x_skip_channel_mlp + + if self.norm is not None: + x = self.norm[self.n_norms * index + 1](x) + + if index < (self.n_layers - 1): + x = self.non_linearity(x) + + return x + + def forward_with_preactivation(self, x, index=0, output_shape=None): + # Apply non-linear activation (and norm) + # before this block's convolution/forward pass: + x = self.non_linearity(x) + + if self.norm is not None: + x = self.norm[self.n_norms * index](x) + + x_skip_fno = self.fno_skips[index](x) + x_skip_fno = self.convs[index].transform(x_skip_fno, output_shape=output_shape) + + if self.use_channel_mlp: + x_skip_channel_mlp = self.channel_mlp_skips[index](x) + x_skip_channel_mlp = self.convs[index].transform(x_skip_channel_mlp, output_shape=output_shape) + + if self.stabilizer == "tanh": + if self.complex_data: + x = ctanh(x) + else: + # x = torch.tanh(x) + x = 
mindspore.mint.tanh(x) + + x_fno = self.convs[index](x, output_shape=output_shape) + + x = x_fno + x_skip_fno + + if index < (self.n_layers - 1): + x = self.non_linearity(x) + + if self.norm is not None: + x = self.norm[self.n_norms * index + 1](x) + + if self.use_channel_mlp: + x = self.channel_mlp[index](x) + x_skip_channel_mlp + + return x + + @property + def n_modes(self): + return self._n_modes + + @n_modes.setter + def n_modes(self, n_modes): + for i in range(self.n_layers): + self.convs[i].n_modes = n_modes + self._n_modes = n_modes + + def get_block(self, indices): + """Returns a sub-FNO Block layer from the jointly parametrized main block + + The parametrization of an FNOBlock layer is shared with the main one. + """ + if self.n_layers == 1: + raise ValueError( + "A single layer is parametrized, directly use the main class." + ) + + return SubModule(self, indices) + + def __getitem__(self, indices): + return self.get_block(indices) + + +# class SubModule(nn.Module): +class SubModule(nn.Cell): + """Class representing one of the sub_module from the mother joint module + + Notes + ----- + This relies on the fact that nn.Parameters are not duplicated: + if the same nn.Parameter is assigned to multiple modules, + they all point to the same data, which is shared. + """ + + def __init__(self, main_module, indices): + super().__init__() + self.main_module = main_module + self.indices = indices + + # def forward(self, x): + def construct(self, x): + # return self.main_module.forward(x, self.indices) + return self.main_module.construct(x, self.indices) \ No newline at end of file diff --git a/MindFlow/mindflow/GINO/layers/gno_weighting_functions.py b/MindFlow/mindflow/GINO/layers/gno_weighting_functions.py new file mode 100644 index 000000000..e25dd33e3 --- /dev/null +++ b/MindFlow/mindflow/GINO/layers/gno_weighting_functions.py @@ -0,0 +1,58 @@ +from functools import partial +# import torch +import mindspore +import math + +def bump_cutoff(x, radius=1., scale=1., eps=1e-7): + out = x.clip(0., radius) / radius + out = - 1 / ((1 - out ** 2) + eps) + # return out.exp() * torch.e * scale + return out.exp() * mindspore.Tensor(math.e) *scale + +def half_cos_cutoff(x, radius=1., scale=1.): + x = x / radius + # return scale * (0.5 * torch.cos(torch.pi * x) + 0.5) + return scale * (0.5 * mindspore.mint.cos(mindspore.Tensor(math.pi)* x) + 0.5) + +def quadr_cutoff(x, radius=1., scale=1.): + x = x / radius + left = 1 - 2 * x ** 2 + right = 2 * (1 - x) ** 2 + # return scale * torch.where(x < 0.5, left, right) + return scale * mindspore.mint.where(x < 0.5, left, right) + +def quartic_cutoff(x, radius=1., scale=1.): + a = scale / radius ** 4 + c = - 2 * scale / radius ** 2 + return a * x ** 4 + c * x ** 2 + scale + +def octic_cutoff(x, radius=1., scale=1.): + x = x / radius + return scale * (-3 * x ** 8 + 8 * x ** 6 - 6 * x ** 4 + 1) + +WEIGHTING_FN_REGISTRY = { + "bump": bump_cutoff, + "half_cos": half_cos_cutoff, + "quadr": quadr_cutoff, + "quartic": quartic_cutoff, + "octic": octic_cutoff, +} + +def dispatch_weighting_fn(weight_function_name : str, sq_radius: float, scale: float): + ''' + Select a GNO weighting function for use in output GNO + of a Mollified Graph Neural Operator-based model. 
See [1]_ (add later)

+    Parameters
+    ----------
+    weight_function_name : str
+        name of weighting function to use, keyed to ``WEIGHTING_FN_REGISTRY`` above
+    sq_radius : float
+        squared radius of GNO neighborhoods for Nyström approximation
+    scale : float
+        factor by which to scale all weights
+    '''
+    base_func = WEIGHTING_FN_REGISTRY.get(weight_function_name)
+    if base_func is None:
+        raise NotImplementedError(f"weighting function should be one of {list(WEIGHTING_FN_REGISTRY.keys())}, got {weight_function_name}")
+    return partial(base_func, radius=sq_radius, scale=scale)
\ No newline at end of file
diff --git a/MindFlow/mindflow/GINO/layers/integral_transform.py b/MindFlow/mindflow/GINO/layers/integral_transform.py
new file mode 100644
index 000000000..c212221f5
--- /dev/null
+++ b/MindFlow/mindflow/GINO/layers/integral_transform.py
@@ -0,0 +1,243 @@
+# import torch
+# from torch import nn
+# import torch.nn.functional as F
+import mindspore
+from mindspore import nn
+from mindspore import mint
+import mindspore.mint.nn.functional as F
+
+from .channel_mlp import LinearChannelMLP
+from .segment_csr import segment_csr
+
+
+class IntegralTransform(nn.Cell):
+    """Integral Kernel Transform (GNO)
+    Computes one of the following:
+        (a) \\int_{A(x)} k(x, y) dy
+        (b) \\int_{A(x)} k(x, y) * f(y) dy
+        (c) \\int_{A(x)} k(x, y, f(y)) dy
+        (d) \\int_{A(x)} k(x, y, f(y)) * f(y) dy
+
+    x : Points for which the output is defined
+
+    y : Points for which the input is defined
+
+    A(x) : A subset of all points y (depending on\
+        each x) over which to integrate
+
+    k : A kernel parametrized as a MLP (LinearChannelMLP)
+
+    f : Input function to integrate against given\
+        on the points y
+
+    If f is not given, a transform of type (a)
+    is computed. Otherwise transforms (b), (c),
+    or (d) are computed. The sets A(x) are specified
+    as a graph in CRS format.
+
+    Parameters
+    ----------
+    channel_mlp : mindspore.nn.Cell, default None
+        MLP parametrizing the kernel k. Input dimension
+        should be dim x + dim y or dim x + dim y + dim f.
+        MLP should not be pointwise and should only operate across
+        channels to preserve the discretization-invariance of the
+        kernel integral.
+    channel_mlp_layers : list, default None
+        List of layer sizes specifying an MLP which
+        parametrizes the kernel k. The MLP will be
+        instantiated by the LinearChannelMLP class
+    channel_mlp_non_linearity : callable, default mindspore.mint.nn.functional.gelu
+        Non-linearity used by the
+        LinearChannelMLP class. Only used if channel_mlp_layers is
+        given and channel_mlp is None
+    transform_type : str, default 'linear'
+        Which integral transform to compute. The mapping is:
+        'linear_kernelonly' -> (a)
+        'linear' -> (b)
+        'nonlinear_kernelonly' -> (c)
+        'nonlinear' -> (d)
+        If the input f is not given then (a) is computed
+        by default independently of this parameter.
+    use_torch_scatter : bool, default True
+        whether to use ``torch-scatter`` to perform grouped reductions in the ``IntegralTransform``.
+        If False, uses native Python reduction in ``neuralop.layers.segment_csr``, by default True
+
+    .. warning::
+
+        ``torch-scatter`` is an optional dependency that conflicts with the newest versions of PyTorch,
+        so you must handle the conflict explicitly in your environment. See :ref:`torch_scatter_dependency`
+        for more information.
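+
+    Examples
+    --------
+    A rough sketch of a type-(b) transform, assuming 3-d coordinates and a
+    16-channel f_y; ``NeighborSearch`` comes from this module's
+    ``neighbor_search.py``:
+
+    .. code-block:: python
+
+        nb_search = NeighborSearch(use_open3d=False)
+        neighbors = nb_search(data=y, queries=x, radius=0.25)
+        # kernel MLP input size is dim(y) + dim(x); its output size must
+        # match the channels of f_y for the 'linear' transform
+        gno = IntegralTransform(channel_mlp_layers=[6, 64, 16],
+                                transform_type="linear")
+        out_features = gno(y=y, neighbors=neighbors, x=x, f_y=f_y)   # [m, 16]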
+ """ + + def __init__( + self, + channel_mlp=None, + channel_mlp_layers=None, + channel_mlp_non_linearity=F.gelu, + transform_type="linear", + weighting_fn=None, + reduction='sum', + use_torch_scatter=True, + ): + super().__init__() + + assert channel_mlp is not None or channel_mlp_layers is not None + + self.reduction = reduction + self.transform_type = transform_type + self.use_torch_scatter = use_torch_scatter + if ( + self.transform_type != "linear_kernelonly" + and self.transform_type != "linear" + and self.transform_type != "nonlinear_kernelonly" + and self.transform_type != "nonlinear" + ): + raise ValueError( + f"Got transform_type={transform_type} but expected one of " + "[linear_kernelonly, linear, nonlinear_kernelonly, nonlinear]" + ) + + if channel_mlp is None: + self.channel_mlp = LinearChannelMLP(layers=channel_mlp_layers, non_linearity=channel_mlp_non_linearity) + else: + self.channel_mlp = channel_mlp + + self.weighting_fn = weighting_fn + + def construct(self, y, neighbors, x=None, f_y=None, weights=None): + """Compute a kernel integral transform. Assumes x=y if not specified. + + Integral is taken w.r.t. the neighbors. + + If no weights are given, a Monte-Carlo approximation is made. + + .. note :: For transforms of type 0 or 2, out channels must be + the same as the channels of f + + Parameters + ---------- + y : mindspore.Tensor of shape [n, d1] + n points of dimension d1 specifying + the space to integrate over. + If batched, these must remain constant + over the whole batch so no batch dim is needed. + neighbors : dict + The sets A(x) given in CRS format. The + dict must contain the keys "neighbors_index" + and "neighbors_row_splits." For descriptions + of the two, see NeighborSearch. + If batch > 1, the neighbors must be constant + across the entire batch. + x : mindspore.Tensor of shape [m, d2], default None + m points of dimension d2 over which the + output function is defined. If None, + x = y. + f_y : mindspore.Tensor of shape [batch, n, d3] or [n, d3], default None + Function to integrate the kernel against defined + on the points y. The kernel is assumed diagonal + hence its output shape must be d3 for the transforms + (b) or (d). If None, (a) is computed. + weights : mindspore.Tensor of shape [n,], default None + Weights for each point y proprtional to the + volume around f(y) being integrated. For example, + suppose d1=1 and let y_1 < y_2 < ... < y_{n+1} + be some points. Then, for a Riemann sum, + the weights are y_{j+1} - y_j. If None, + 1/|A(x)| is used. + + Output + ---------- + out_features : mindspore.Tensor of shape [batch, m, d4] or [m, d4] + Output function given on the points x. + d4 is the output size of the kernel k. 
+ """ + + if x is None: + x = y + + rep_features = y[neighbors["neighbors_index"]] + + # batching only matters if f_y (latent embedding) values are provided + batched = False + batch_size = None + in_features = None + # # f_y has a batch dim IFF batched=True + # if f_y is not None: + # if f_y.ndim == 3: + # batched = True + # batch_size = f_y.shape[0] + # in_features = f_y[:, neighbors["neighbors_index"], :] + # elif f_y.ndim == 2: + # batched = False + # in_features = f_y[neighbors["neighbors_index"]] + + num_reps = ( + neighbors["neighbors_row_splits"][1:] + - neighbors["neighbors_row_splits"][:-1] + ) + + # self_features = mint.repeat_interleave(x, num_reps, dim=0) + self_features = mindspore.numpy.repeat(x, num_reps, axis=0) + + agg_features = mint.cat([rep_features, self_features], dim=-1) + + # f_y has a batch dim IFF batched=True + if f_y is not None: + if f_y.ndim == 3: + batched = True + batch_size = f_y.shape[0] + in_features = f_y[:, neighbors["neighbors_index"], :] + elif f_y.ndim == 2: + in_features = f_y[neighbors["neighbors_index"]] + + if self.transform_type in ["nonlinear_kernelonly", "nonlinear"]: + if batched and batch_size is not None: + # repeat agg features for every example in the batch + agg_features = agg_features.repeat( + [batch_size] + [1] * agg_features.ndim + ) + agg_features = mint.cat([agg_features, in_features], dim=-1) + + # if f_y is not None and ( + # self.transform_type == "nonlinear_kernelonly" + # or self.transform_type == "nonlinear" + # ): + # if batched: + # # repeat agg features for every example in the batch + # agg_features = agg_features.repeat( + # [batch_size] + [1] * agg_features.ndim + # ) + # agg_features = mint.cat([agg_features, in_features], dim=-1) + + rep_features = self.channel_mlp(agg_features) + + if f_y is not None and self.transform_type != "nonlinear_kernelonly": + # if we have a batch of outputs (3d incl. batch dim) and unbatched inputs, + # create an identical batch dim in rep_features + if rep_features.ndim == 2 and batched: + rep_features = rep_features.unsqueeze(0).repeat([batch_size] + [1] * rep_features.ndim) + rep_features.mul_(in_features) + + # Weight neighbors in each neighborhood, first according to the neighbor search (mollified GNO) + # and second according to individually-provided weights. 
+ nbr_weights = neighbors.get("weights") + if nbr_weights is None: + nbr_weights = weights + if nbr_weights is None and self.weighting_fn is not None: + raise KeyError("if a weighting function is provided, your neighborhoods must contain weights.") + if nbr_weights is not None: + nbr_weights = nbr_weights.unsqueeze(-1).unsqueeze(0) + if self.weighting_fn is not None: + nbr_weights = self.weighting_fn(nbr_weights) + rep_features.mul_(nbr_weights) + reduction = "sum" # Force sum reduction for weighted GNO layers + + else: + reduction = self.reduction + + splits = neighbors["neighbors_row_splits"] + if batched: + splits = splits.unsqueeze(0).repeat([batch_size] + [1] * (splits.ndim)) + + out_features = segment_csr(rep_features, splits, reduction=reduction, use_scatter=self.use_torch_scatter) + return out_features diff --git a/MindFlow/mindflow/GINO/layers/neighbor_search.py b/MindFlow/mindflow/GINO/layers/neighbor_search.py new file mode 100644 index 000000000..6cec917e6 --- /dev/null +++ b/MindFlow/mindflow/GINO/layers/neighbor_search.py @@ -0,0 +1,116 @@ +# import torch +# from torch import nn +import mindspore +from mindspore import nn +from mindspore import mint + +# only import open3d if built +open3d_built = False +# try: +# from open3d.ml.torch.layers import FixedRadiusSearch +# open3d_built = True +# except: +# pass + +# Uses open3d by default which, as of October 2024, requires torch 2.0 and cuda11.* +class NeighborSearch(nn.Cell): + """ + Neighborhood search between two arbitrary coordinate meshes. + For each point `x` in `queries`, returns a set of the indices of all points `y` in `data` + within the ball of radius r `B_r(x)` + + Parameters + ---------- + use_open3d : bool + Whether to use open3d or native PyTorch implementation + NOTE: open3d implementation requires 3d data + """ + def __init__(self, use_open3d=True, return_norm=False): + super().__init__() + # if use_open3d and open3d_built: # slightly faster, works on GPU in 3d only + # self.search_fn = FixedRadiusSearch() + # self.use_open3d = use_open3d + # else: # slower fallback, works on GPU and CPU + self.search_fn = native_neighbor_search + self.use_open3d = False + self.return_norm = return_norm + + + def construct(self, data, queries, radius): + """ + Find the neighbors, in data, of each point in queries + within a ball of radius. Returns in CRS format. + + Parameters + ---------- + data : mindspore.Tensor of shape [n, d] + Search space of possible neighbors + NOTE: open3d requires d=3 + queries : mindspore.Tensor of shape [m, d] + Points for which to find neighbors + NOTE: open3d requires d=3 + radius : float + Radius of each ball: B(queries[j], radius) + + Output + ---------- + return_dict : dict + Dictionary with keys: neighbors_index, neighbors_row_splits + neighbors_index: mindspore.Tensor with dtype=mindspore.int64 + Index of each neighbor in data for every point + in queries. Neighbors are ordered in the same orderings + as the points in queries. Open3d and torch_cluster + implementations can differ by a permutation of the + neighbors for every point. + neighbors_row_splits: mindspore.Tensor of shape [m+1] with dtype=mindspore.int64 + The value at index j is the sum of the number of + neighbors up to query point j-1. First element is 0 + and last element is the total number of neighbors. 
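+
+            For example, if the three query points have 2, 0, and 1
+            neighbors respectively, then neighbors_row_splits is
+            [0, 2, 2, 3] and neighbors_index holds 3 entries.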
+ """ + return_dict = {} + + if self.use_open3d: + search_return = self.search_fn(data, queries, radius) + return_dict['neighbors_index'] = search_return.neighbors_index.long() + return_dict['neighbors_row_splits'] = search_return.neighbors_row_splits.long() + + else: + return_dict = self.search_fn(data, queries, radius, self.return_norm) + + return return_dict + +def native_neighbor_search(data: mindspore.Tensor, queries: mindspore.Tensor, radius: float, return_norm: bool=False): + """ + Native PyTorch implementation of a neighborhood search + between two arbitrary coordinate meshes. + + Parameters + ----------- + + data : mindspore.Tensor + vector of data points from which to find neighbors + queries : mindspore.Tensor + centers of neighborhoods + radius : float + size of each neighborhood + """ + nbr_dict = {} + + # compute pairwise distances + all_dists = mint.cdist(queries, data) # shaped num query points x num data points + # keep zero-distance points + eps = 1e-7 + all_dists = mint.where(all_dists == 0., eps, all_dists) + dists = mint.where(all_dists <= radius, all_dists, 0.) # i,j is 1 if j is i's neighbor + nbr_indices = dists.nonzero()[:,1:].reshape(-1,) # only keep the column indices + if return_norm: + weights = dists[dists.nonzero(as_tuple=True)] + nbr_dict['weights'] = weights **2 # weighting function computed on squared norms + in_nbr = mint.where(dists > 0, 1., 0.,) + nbrhd_sizes = mint.cumsum(mint.sum(in_nbr, dim=1), dim=0) # num points in each neighborhood, summed cumulatively + splits = mint.cat((mindspore.tensor([0.]), nbrhd_sizes)) + + nbr_dict['neighbors_index'] = nbr_indices.long() + nbr_dict['neighbors_row_splits'] = splits.long() + + return nbr_dict \ No newline at end of file diff --git a/MindFlow/mindflow/GINO/layers/segment_csr.py b/MindFlow/mindflow/GINO/layers/segment_csr.py new file mode 100644 index 000000000..fdd41d58d --- /dev/null +++ b/MindFlow/mindflow/GINO/layers/segment_csr.py @@ -0,0 +1,100 @@ +from typing import Literal +import importlib + +import mindspore +from mindspore import mint + +def segment_csr( + src: mindspore.Tensor, + indptr: mindspore.Tensor, + reduction: Literal["mean", "sum"], + use_scatter=True, +): + """segment_csr reduces all entries of a CSR-formatted + matrix by summing or averaging over neighbors. + + Used to reduce features over neighborhoods + in neuralop.layers.IntegralTransform + + If use_scatter is set to False or torch_scatter is not + properly built, segment_csr falls back to a naive PyTorch implementation + + Note: the native version is mainly intended for running tests on + CPU-only GitHub CI runners to get around a versioning issue. + torch_scatter should be installed and built if possible. + + Parameters + ---------- + src : mindspore.Tensor + tensor of features for each point + indptr : mindspore.Tensor + splits representing start and end indices + of each neighborhood in src + reduce : Literal['mean', 'sum'], optional + how to reduce a neighborhood. if mean, + reduce by taking the average of all neighbors. + Otherwise take the sum. + use_scatter : bool, optional + whether to use ``torch-scatter.segment_csr``. If False, uses native Python reduction. + By default True + + .. warning:: + + ``torch-scatter`` is an optional dependency that conflicts with the newest versions of PyTorch, + so you must handle the conflict explicitly in your environment. See :ref:`torch_scatter_dependency` + for more information. 
+ """ + if reduction not in ["mean", "sum"]: + raise ValueError("reduce must be one of 'mean', 'sum'") + + # if ( + # importlib.util.find_spec("torch_scatter") is not None + # and use_scatter + # ): + # """only import torch_scatter when cuda is available""" + # import torch_scatter.segment_csr as scatter_segment_csr + + # return scatter_segment_csr(src, indptr, reduce=reduction) + + # else: + # if use_scatter: + # print("Warning: use_scatter is True but torch_scatter is not properly built. \ + # Defaulting to naive PyTorch implementation") + # if batched, shape [b, n_reps, channels] + # otherwise shape [n_reps, channels] + if src.ndim == 3: + batched = True + point_dim = 1 + else: + batched = False + point_dim = 0 + + # if batched, shape [b, n_out, channels] + # otherwise shape [n_out, channels] + output_shape = list(src.shape) + n_out = indptr.shape[point_dim] - 1 + output_shape[point_dim] = n_out + + out = mint.zeros(output_shape) + + for i in range(n_out): + # reduce all indices pointed to in indptr from src into out + if batched: + from_idx = (slice(None), slice(indptr[0,i], indptr[0,i+1])) + ein_str = 'bio->bo' + start = indptr[0,i] + n_nbrs = indptr[0,i+1] - start + to_idx = (slice(None), i) + else: + from_idx = slice(indptr[i], indptr[i+1]) + ein_str = 'io->o' + start = indptr[i] + n_nbrs = indptr[i+1] - start + to_idx = i + src_from = src[from_idx] + if n_nbrs > 0: + to_reduce = mint.einsum(ein_str, src_from) + if reduction == "mean": + to_reduce /= n_nbrs + out[to_idx] += to_reduce + return out diff --git a/MindFlow/mindflow/GINO/layers/tests/test_assert_close.py b/MindFlow/mindflow/GINO/layers/tests/test_assert_close.py new file mode 100644 index 000000000..05b1bafce --- /dev/null +++ b/MindFlow/mindflow/GINO/layers/tests/test_assert_close.py @@ -0,0 +1,74 @@ +import mindspore +import unittest +from ..assert_close import assert_close + +class TestAssertClose(unittest.TestCase): + def test_basic_equal(self): + # 测试基本相等情况 + actual = mindspore.tensor([1.0, 2.0, 3.0]) + expected = mindspore.tensor([1.0, 2.0, 3.0]) + assert_close(actual, expected) + + def test_close_within_tolerance(self): + # 测试在容差范围内的情况 + actual = mindspore.tensor([1.00001, 2.0, 3.0]) + expected = mindspore.tensor([1.0, 2.0, 3.0]) + assert_close(actual, expected, rtol=1e-4, atol=1e-4) + + def test_not_close(self): + # 测试超出容差范围的情况 + actual = mindspore.tensor([1.01, 2.0, 3.0]) + expected = mindspore.tensor([1.0, 2.0, 3.0]) + with self.assertRaises(AssertionError) as cm: + assert_close(actual, expected, rtol=1e-4, atol=1e-4) + self.assertIn("张量不接近", str(cm.exception)) + + def test_different_shapes(self): + # 测试不同形状的情况 + actual = mindspore.tensor([1.0, 2.0, 3.0]) + expected = mindspore.tensor([[1.0, 2.0], [3.0, 4.0]]) + with self.assertRaises(AssertionError) as cm: + assert_close(actual, expected) + self.assertIn("张量形状不匹配", str(cm.exception)) + + def test_different_dtypes(self): + # 测试不同数据类型的情况 + actual = mindspore.tensor([1.0, 2.0, 3.0], dtype=mindspore.float32) + expected = mindspore.tensor([1.0, 2.0, 3.0], dtype=mindspore.float64) + with self.assertRaises(AssertionError) as cm: + assert_close(actual, expected) + self.assertIn("张量数据类型不匹配", str(cm.exception)) + + def test_equal_nan_true(self): + # 测试equal_nan=True的情况 + actual = mindspore.tensor([1.0, float('nan'), 3.0]) + expected = mindspore.tensor([1.0, float('nan'), 3.0]) + assert_close(actual, expected, equal_nan=True) + + def test_equal_nan_false(self): + # 测试equal_nan=False的情况 + actual = mindspore.tensor([1.0, float('nan'), 3.0]) + expected = 
mindspore.tensor([1.0, 2.0, 3.0]) + with self.assertRaises(AssertionError) as cm: + assert_close(actual, expected, equal_nan=False) + self.assertIn("张量中NaN存在性不一致", str(cm.exception)) + + def test_nan_different_positions(self): + # 测试NaN位置不同的情况 + actual = mindspore.tensor([1.0, float('nan'), 3.0]) + expected = mindspore.tensor([float('nan'), 2.0, 3.0]) + with self.assertRaises(AssertionError) as cm: + assert_close(actual, expected, equal_nan=True) + self.assertIn("张量中NaN位置不一致", str(cm.exception)) + + def test_check_stride(self): + # 测试步长检查 + x = mindspore.tensor([[1, 2], [3, 4]]) + actual = x.transpose(0, 1) + expected = mindspore.tensor([[1, 3], [2, 4]]) + with self.assertRaises(AssertionError) as cm: + assert_close(actual, expected, check_stride=True) + self.assertIn("张量步长不匹配", str(cm.exception)) + +if __name__ == '__main__': + unittest.main() \ No newline at end of file diff --git a/MindFlow/mindflow/GINO/layers/tests/test_gno_block.py b/MindFlow/mindflow/GINO/layers/tests/test_gno_block.py new file mode 100644 index 000000000..77595e334 --- /dev/null +++ b/MindFlow/mindflow/GINO/layers/tests/test_gno_block.py @@ -0,0 +1,116 @@ +# import torch +# from torch.autograd import grad +import mindspore +from mindspore import mint +import pytest +# from tensorly import tenalg +# tenalg.set_backend("einsum") + +# Parameterize use of torch_scatter if it is built +# try: +# from torch_scatter import segment_csr +# use_torch_scatter = [True, False] +# except: +use_torch_scatter = [False] + +from ..gno_block import GNOBlock + +# Fixed variables +in_channels = 3 +out_channels = 3 +mlp_hidden_layers = [16,16,16] + +# data parameters +n_in = 100 +n_out = 100 + +# test open3d mode if built +# try: +# from neighbor_search import FixedRadiusSearch +# open3d_built = True +# except: +open3d_built = False + +if open3d_built: + use_open3d_parametrize = [True, False] +else: + use_open3d_parametrize = [False] + +@pytest.mark.parametrize("batch_size", [1,4]) +@pytest.mark.parametrize("gno_coord_dim", [2,3]) +@pytest.mark.parametrize("gno_pos_embed_type", ['nerf', 'transformer', None]) +@pytest.mark.parametrize( + "gno_transform_type", ["linear", "nonlinear_kernelonly", "nonlinear"] +) +@pytest.mark.parametrize('use_open3d', use_open3d_parametrize) +@pytest.mark.parametrize('use_torch_scatter', use_torch_scatter) +def test_gno_block(gno_transform_type, gno_coord_dim, gno_pos_embed_type, batch_size, use_open3d, use_torch_scatter): + # if torch.backends.cuda.is_built(): + # device = torch.device("cuda:0") + # else: + # device = torch.device("cpu:0") + + use_open3d = use_open3d and (gno_coord_dim == 3) + + + gno_block = GNOBlock( + in_channels=in_channels, + out_channels=out_channels, # dummy var currently + coord_dim=gno_coord_dim, + pos_embedding_type=gno_pos_embed_type, + radius=0.25, + channel_mlp_layers=mlp_hidden_layers, + transform_type=gno_transform_type, + use_open3d_neighbor_search=use_open3d, + use_torch_scatter_reduce=use_torch_scatter + ) + + # create input geometry and output queries + input_geom_shape = [n_in, gno_coord_dim] + input_geom = mint.randn(*input_geom_shape) + + output_queries_shape = [n_out, gno_coord_dim] + output_queries = mint.randn(*output_queries_shape) + + f_y = None + if gno_transform_type != "linear": + # create data and features + f_y_shape = [batch_size, n_in, in_channels] + f_y = mint.randn(*f_y_shape) + # require and retain grad to check for backprop + f_y.requires_grad_(True) + + def forward_fn(y, x, f_y): + logits = gno_block(y=y, x=x, f_y=f_y) + assert logits.isfinite().all() 
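+        # when batched, reduce only the first example so that gradients
+        # w.r.t. the remaining examples of f_y should be exactly zero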
+        if batch_size > 1:
+            loss = logits[0].sum()
+        else:
+            loss = logits.sum()
+        return loss, logits
+
+    # differentiate w.r.t. the input features f_y (positional arg 2) when
+    # they exist, and w.r.t. the block's parameters in every case
+    grad_position = 2 if f_y is not None else None
+    grad_fn = mindspore.value_and_grad(forward_fn, grad_position, gno_block.trainable_params(), has_aux=True)
+    (loss, out), grads = grad_fn(input_geom, output_queries, f_y)
+    # Check output size
+    # Batched outputs only matter in the nonlinear kernel use case
+    if gno_transform_type != "linear":
+        assert list(out.shape) == [batch_size, n_out, out_channels]
+    else:
+        assert list(out.shape) == [n_out, out_channels]
+
+    # Check backward pass
+    if batch_size > 1 and gno_transform_type != "linear":
+        # with both grad_position and weights set, value_and_grad returns
+        # (input_grads, weight_grads); assert f_y[1:] accumulates no grad
+        f_y_grad = grads[0]
+        assert not f_y_grad[1:].nonzero().any()
diff --git a/MindFlow/mindflow/GINO/layers/tests/test_grid_embeddings.py b/MindFlow/mindflow/GINO/layers/tests/test_grid_embeddings.py
new file mode 100644
index 000000000..893a775df
--- /dev/null
+++ b/MindFlow/mindflow/GINO/layers/tests/test_grid_embeddings.py
@@ -0,0 +1,45 @@
+import random
+
+from ..assert_close import assert_close
+import pytest
+import mindspore
+from mindspore import mint
+
+from ..embeddings import GridEmbedding2D, GridEmbeddingND
+
+# Testing grid-based pos encoding: choose a random grid
+# point and assert the proper encoding is applied there
+
+def test_GridEmbedding2D():
+    grid_boundaries = [[0,1], [0,1]]
+    pos_embed = GridEmbedding2D(in_channels=1,
+                                grid_boundaries=grid_boundaries)
+
+    input_res = (20,20)
+    x = mint.randn(1,1,*input_res)
+    x = pos_embed(x)
+
+    index = [random.randint(0, res-1) for res in input_res]
+    true_coords = x[0,1:,index[0], index[1]].squeeze() # grab pos encoding channels at coord index
+    expected_coords = mindspore.tensor([i/j for i,j in zip(index,input_res)])
+    assert_close(true_coords, expected_coords)
+
+@pytest.mark.parametrize('dim', [1,2,3,4])
+def test_GridEmbeddingND(dim):
+    grid_boundaries = [[0,1]] * dim
+    pos_embed = GridEmbeddingND(in_channels=1,
+                                dim=dim,
+                                grid_boundaries=grid_boundaries)
+
+    input_res = [20] * dim
+    x = mint.randn(1,1,*input_res)
+    x = pos_embed(x)
+
+    index = [random.randint(0, res-1) for res in input_res]
+    # grab pos encoding channels at coord index
+    pos_channels = x[0,1:,...]
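+    # build an index that keeps all positional-encoding channels while
+    # fixing the randomly chosen spatial location along every axis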
+    indices = (slice(None), *index)
+    true_coords = pos_channels[indices]
+    expected_coords = mindspore.tensor([i/j for i,j in zip(index,input_res)])
+    assert_close(true_coords, expected_coords)
\ No newline at end of file
diff --git a/MindFlow/mindflow/GINO/layers/tests/test_neighbor_search.py b/MindFlow/mindflow/GINO/layers/tests/test_neighbor_search.py
new file mode 100644
index 000000000..0377b7b9c
--- /dev/null
+++ b/MindFlow/mindflow/GINO/layers/tests/test_neighbor_search.py
@@ -0,0 +1,49 @@
+"""
+Tests fallback neighbor search on a small 2d grid
+that was calculated manually
+"""
+
+import numpy as np
+import mindspore
+from mindspore import mint
+import pytest
+
+from ..neighbor_search import native_neighbor_search
+
+# Manually-calculated CSR list of neighbors
+# in a 5x5 grid on [0,1] X [0,1] for radius=0.3
+
+indices = [0, 1, 5, 0, 1, 2, 6, 1, 2, 3, 7, 2, 3, 4, 8,
+           3, 4, 9, 0, 5, 6, 10, 1, 5, 6, 7, 11, 2, 6, 7,
+           8, 12, 3, 7, 8, 9, 13, 4, 8, 9, 14, 5, 10, 11,
+           15, 6, 10, 11, 12, 16, 7, 11, 12, 13, 17, 8, 12,
+           13, 14, 18, 9, 13, 14, 19, 10, 15, 16, 20, 11, 15,
+           16, 17, 21, 12, 16, 17, 18, 22, 13, 17, 18, 19, 23,
+           14, 18, 19, 24, 15, 20, 21, 16, 20, 21, 22, 17, 21,
+           22, 23, 18, 22, 23, 24, 19, 23, 24]
+
+splits = [0, 3, 7, 11, 15, 18, 22, 27, 32, 37, 41, 45, 50,
+          55, 60, 64, 68, 73, 78, 83, 87, 90, 94, 98, 102, 105]
+
+def test_fallback_nb_search():
+    mesh_grid = np.stack(np.meshgrid(*[np.linspace(0,1,5) for _ in range(2)], indexing="ij"), axis=-1)
+    coords = mindspore.Tensor(mesh_grid.reshape(-1,2)) # reshape into n**d x d coord points
+    return_dict = native_neighbor_search(data=coords, queries=coords, radius=0.3, return_norm=True)
+
+    assert return_dict['neighbors_index'].tolist() == indices
+    assert return_dict['neighbors_row_splits'].tolist() == splits
+    print(f"{return_dict['weights']=}")
+
+    def compute_norm_separate(nbrs, data, queries):
+        return_dict = nbrs
+        num_reps = return_dict['neighbors_row_splits'][1:] - return_dict['neighbors_row_splits'][:-1]
+        # mint.repeat_interleave(queries, num_reps, dim=0) also works on
+        # versions that support Tensor-valued repeats
+        rep_queries = mindspore.numpy.repeat(queries, num_reps, axis=0)
+        rep_data = data[return_dict['neighbors_index']]
+        rep_dist = rep_queries - rep_data
+        # Tensor.sum takes numpy-style `axis`, not torch's `dim`
+        return_dict['squared_norm'] = (rep_dist ** 2).sum(axis=-1)
+        return return_dict
+
+    return_dict = compute_norm_separate(return_dict, coords, coords)
+    print(return_dict["squared_norm"])
\ No newline at end of file
diff --git a/MindFlow/mindflow/GINO/layers/tests/test_segment_csr.py b/MindFlow/mindflow/GINO/layers/tests/test_segment_csr.py
new file mode 100644
index 000000000..28afbf8b0
--- /dev/null
+++ b/MindFlow/mindflow/GINO/layers/tests/test_segment_csr.py
@@ -0,0 +1,47 @@
+import mindspore
+from mindspore import mint
+from ..segment_csr import segment_csr
+
+import pytest
+
+@pytest.mark.parametrize('batch_size', [1,4])
+def test_native_segcsr_shapes(batch_size):
+    n_pts = 25
+    n_channels = 5
+    max_nbrhd_size = 7 # prevent degenerate cases in testing
+
+    # tensor to reduce
+    src = mint.randn((batch_size, n_pts, n_channels))
+
+    # randomly generate index pointer tensor for CSR format
+    nbrhd_sizes = [mindspore.tensor([0])]
+    while sum(nbrhd_sizes) < n_pts:
+        nbrhd_sizes.append(mint.randint(0, int(max_nbrhd_size + 1), (1,)))
+        max_nbrhd_size = int(min(max_nbrhd_size, n_pts - sum(nbrhd_sizes)))
+    # mindspore.tensor cannot build a Tensor from a list of Tensors,
+    # so concatenate the per-neighborhood sizes before the cumulative sum
+    indptr = mint.cumsum(mint.cat(nbrhd_sizes).astype(mindspore.int64), dim=0)
+    if batch_size > 1:
+        # Tensor.repeat follows numpy (element-wise) semantics in MindSpore;
+        # mint.tile gives the torch-style repeat along a new leading dim
+        indptr = mint.tile(indptr, (batch_size, 1))
+    else:
+        src = src.squeeze(0)
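+    # reduce each CSR neighborhood of `src` to one row per segment; with
+    # batch_size == 1 the native segment_csr expects unbatched input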
+    out = segment_csr(src, indptr, reduction='sum', use_scatter=False)
+
+    if batch_size == 1:
+        assert out.shape == (len(indptr) - 1, n_channels)
+    else:
+        assert out.shape == (batch_size, indptr.shape[1] - 1, n_channels)
+
+def test_native_segcsr_reductions():
+    src = mint.ones([10, 3])
+    indptr = mindspore.tensor([0,3,8,10], dtype=mindspore.int64)
+
+    out_sum = segment_csr(src, indptr, reduction='sum', use_scatter=False)
+    assert out_sum.shape == (3,3)
+    # segments have sizes (3, 5, 2), so summing ones gives rows of 3, 5, 2
+    diff = out_sum - mindspore.tensor([[3, 5, 2]]).T * mint.ones([3,3])
+    assert not diff.nonzero().any()
+
+    out_mean = segment_csr(src, indptr, reduction='mean', use_scatter=False)
+    assert out_mean.shape == (3,3)
+    diff = out_mean - mint.ones([3,3])
+    assert not diff.nonzero().any()
\ No newline at end of file
diff --git a/MindFlow/mindflow/GINO/layers/tests/test_sin_embeddings.py b/MindFlow/mindflow/GINO/layers/tests/test_sin_embeddings.py
new file mode 100644
index 000000000..5fe62ae0d
--- /dev/null
+++ b/MindFlow/mindflow/GINO/layers/tests/test_sin_embeddings.py
@@ -0,0 +1,96 @@
+import math
+
+from ..assert_close import assert_close
+import pytest
+import mindspore
+from mindspore import mint
+
+from ..embeddings import SinusoidalEmbedding
+
+# Testing NeRF Embedding: start with a simple range
+# and see that it is embedded properly
+
+batch_size = 4
+num_freqs = 3
+in_channels = 3
+n_in = 2
+max_pos = 10000
+
+
+def test_NeRFEmbedding():
+    nerf_embed = SinusoidalEmbedding(in_channels=in_channels,
+                                     num_frequencies=num_freqs,
+                                     embedding_type='nerf')
+    unbatched_inputs = mint.arange(in_channels) * mindspore.tensor([[1.], [0.5]])
+    embeds = nerf_embed(unbatched_inputs)
+
+    true_outputs = mint.zeros((n_in, in_channels * num_freqs * 2))
+
+    # True values are (sin(2^0 * pi * p), cos(2^0 * pi * p), ... cos(2^(L-1) * pi * p))
+    for channel in range(in_channels):
+        for wavenumber in range(num_freqs):
+            for i in range(2):
+                idx = channel * (num_freqs * 2) + wavenumber * 2 + i
+                freqs = 2 ** wavenumber * math.pi * unbatched_inputs[:, channel]
+                if i == 0:
+                    true_outputs[:, idx] = freqs.sin()
+                else:
+                    true_outputs[:, idx] = freqs.cos()
+    assert_close(embeds, true_outputs)
+
+    batched_inputs = mint.stack([mint.arange(in_channels) * mindspore.tensor([[1.], [0.5]])] * batch_size)
+    embeds = nerf_embed(batched_inputs)
+
+    true_outputs = mint.zeros((batch_size, n_in, in_channels * num_freqs * 2))
+
+    # True values are (sin(2^0 * pi * p), cos(2^0 * pi * p), ... cos(2^(L-1) * pi * p))
+    for channel in range(in_channels):
+        for wavenumber in range(num_freqs):
+            for i in range(2):
+                idx = channel * (num_freqs * 2) + wavenumber * 2 + i
+                freqs = 2 ** wavenumber * math.pi * batched_inputs[:, :, channel]
+                if i == 0:
+                    true_outputs[:, :, idx] = freqs.sin()
+                else:
+                    true_outputs[:, :, idx] = freqs.cos()
+    assert_close(embeds, true_outputs)
+
+
+def test_TransformerEmbedding():
+    sin_embed = SinusoidalEmbedding(in_channels=in_channels,
+                                    num_frequencies=num_freqs,
+                                    embedding_type='transformer',
+                                    max_positions=max_pos)
+    unbatched_inputs = mint.arange(in_channels) * mindspore.tensor([[1.], [0.5]])
+    embeds = sin_embed(unbatched_inputs)
+
+    true_outputs = mint.zeros((n_in, in_channels * num_freqs * 2))
+
+    # True values are (sin(2^0 * pi * p), cos(2^0 * pi * p), ... 
cos(2^(L-1) * pi * p)) + for channel in range(in_channels): + for wavenumber in range(num_freqs): + for i in range(2): + idx = channel * (num_freqs * 2) + wavenumber * 2 + i + freqs = ((1 / max_pos) ** (wavenumber / (sin_embed.num_frequencies * 2))) * unbatched_inputs[:, channel] + if i == 0: + true_outputs[:, idx] = freqs.sin() + else: + true_outputs[:, idx] = freqs.cos() + assert_close(embeds, true_outputs) + + batched_inputs = mint.stack([mint.arange(in_channels) * mindspore.tensor([[1.], [0.5]])] * batch_size) + embeds = sin_embed(batched_inputs) + + true_outputs = mint.zeros((batch_size, n_in, in_channels * num_freqs * 2)) + + # True values are (sin(2^0 * pi * p), cos(2^0 * pi * p), ... cos(2^(L-1) * pi * p)) + for channel in range(in_channels): + for wavenumber in range(num_freqs): + for i in range(2): + idx = channel * (num_freqs * 2) + wavenumber * 2 + i + freqs = ((1 / max_pos) ** (wavenumber / (sin_embed.num_frequencies * 2))) * batched_inputs[:, :, channel] + if i == 0: + true_outputs[:, :, idx] = freqs.sin() + else: + true_outputs[:, :, idx] = freqs.cos() + assert_close(embeds, true_outputs) \ No newline at end of file diff --git a/MindFlow/mindflow/GINO/models/gino.py b/MindFlow/mindflow/GINO/models/gino.py new file mode 100644 index 000000000..8ecff793d --- /dev/null +++ b/MindFlow/mindflow/GINO/models/gino.py @@ -0,0 +1,533 @@ +from functools import partial +# import torch +# import torch.nn.functional as F + +import mindspore +import mindspore.mint.nn.functional as F + +import time + +from .base_model import BaseModel + +from ..layers.channel_mlp import ChannelMLP +from ..layers.embeddings import SinusoidalEmbedding +from ..layers.fno_block import FNOBlocks +from ..layers.spectral_convolution import SpectralConv +from ..layers.gno_block import GNOBlock +from ..layers.gno_weighting_functions import dispatch_weighting_fn + +class GINO(BaseModel): + """ + GINO: Geometry-informed Neural Operator. Learns a mapping between + functions presented over arbitrary coordinate meshes. The model carries + global integration through spectral convolution layers in an intermediate + latent space, as described in [1]_. Optionally enables a weighted output + GNO for use in a Mollified Graph Neural Operator scheme, as introduced in [2]_. + + Parameters + ---------- + in_channels : int + feature dimension of input points + out_channels : int + feature dimension of output points + latent_feature_channels : int, optional + number of channels in optional latent feature map + to concatenate onto latent embeddings before + the FNO's forward pass, default None + projection_channel_ratio : int, optional + ratio of pointwise projection channels in the final ``ChannelMLP`` + to ``fno_hidden_channels``, by default 4. The number of projection channels + in the final ``ChannelMLP`` is computed by + ``projection_channel_ratio * fno_hidden_channels`` (i.e. default 256) + gno_coord_dim : int, optional + geometric dimension of input/output queries, by default 3 + in_gno_radius : float, optional + radius in input space for GNO neighbor search, by default 0.033 + out_gno_radius : float, optional + radius in output space for GNO neighbor search, by default 0.033 + gno_weighting_function : Literal{'half_cos', 'bump', 'quartic', 'quadr', 'octic'}, optional + Choice of weighting function to use in the output GNO for + Mollified Graph Neural Operator-based models. + See ``neuralop.layers.gno_weighting_functions`` for more details. 
+    gno_weight_function_scale : float, optional
+        Factor by which to scale weights from the GNO weighting function,
+        by default 1.
+        If ``gno_weighting_function`` is ``None``, this is not used.
+    in_gno_transform_type : str, optional
+        transform type parameter for input GNO, by default 'linear'
+        see neuralop.layers.gno_block for more details
+    out_gno_transform_type : str, optional
+        transform type parameter for output GNO, by default 'linear'
+        see neuralop.layers.gno_block for more details
+    in_gno_pos_embed_type : literal `{'transformer', 'nerf'}` | None
+        type of optional sinusoidal positional embedding to use in input GNOBlock,
+        by default `'transformer'`
+    out_gno_pos_embed_type : literal `{'transformer', 'nerf'}` | None
+        type of optional sinusoidal positional embedding to use in output GNOBlock,
+        by default `'transformer'`
+    fno_in_channels : int, optional
+        number of input channels for FNO, by default 3
+    fno_n_modes : tuple, optional
+        number of modes along each dimension
+        to use in FNO, by default (16, 16, 16)
+    fno_hidden_channels : int, optional
+        hidden channels for use in FNO, by default 64
+    fno_lifting_channel_ratio : int, optional
+        ratio of lifting channels to ``fno_hidden_channels``, by default 2.
+        The number of lifting channels in the lifting block of the FNO is
+        fno_lifting_channel_ratio * hidden_channels (i.e. default 128)
+    fno_n_layers : int, optional
+        number of layers in FNO, by default 4
+
+    Other Parameters
+    ----------------
+    gno_embed_channels: int
+        dimension of optional per-channel embedding to use in GNOBlock,
+        by default 32
+    gno_embed_max_positions: int
+        max positions of optional per-channel embedding to use in GNOBlock,
+        by default 10000. If `gno_pos_embed_type != 'transformer'`, value is unused.
+    in_gno_channel_mlp_hidden_layers : list, optional
+        widths of hidden layers in input GNO, by default [80, 80, 80]
+    out_gno_channel_mlp_hidden_layers : list, optional
+        widths of hidden layers in output GNO, by default [512, 256]
+    gno_channel_mlp_non_linearity : nn.Module, optional
+        nonlinearity to use in gno ChannelMLP, by default F.gelu
+    gno_use_open3d : bool, optional
+        whether to use open3d neighbor search, by default True;
+        if False, uses the framework-native fallback neighbor search
+    gno_use_torch_scatter : bool, optional
+        whether to use ``torch-scatter`` to perform grouped reductions in the ``IntegralTransform``.
+        If False, uses native Python reduction in ``neuralop.layers.segment_csr``, by default True
+
+        .. warning::
+
+            ``torch-scatter`` is an optional dependency that conflicts with the newest versions of PyTorch,
+            so you must handle the conflict explicitly in your environment. See :ref:`torch_scatter_dependency`
+            for more information.
+    out_gno_tanh : str | None, optional
+        where to apply tanh to stabilize the output GNO: one of
+        {'latent_embed', 'both'}, by default None
+    fno_resolution_scaling_factor : float | None, optional
+        factor by which to scale output of FNO, by default None
+    fno_incremental_n_modes : list[int] | None, defaults to None
+        if passed, sets n_modes separately for each FNO layer.
+    fno_block_precision : str, defaults to 'full'
+        data precision to compute within fno block
+    fno_use_channel_mlp : bool, defaults to True
+        Whether to use a ChannelMLP layer after each FNO block.
+    fno_channel_mlp_dropout : float, defaults to 0
+        dropout parameter of above ChannelMLP.
+    fno_channel_mlp_expansion : float, defaults to 0.5
+        expansion parameter of above ChannelMLP.
+    fno_non_linearity : nn.Module, defaults to F.gelu
+        nonlinear activation function between each FNO layer.
+    fno_stabilizer : nn.Module | None, defaults to None
+        By default None, otherwise tanh is used before FFT in the FNO block.
+    fno_norm : nn.Module | None, defaults to None
+        normalization layer to use in FNO.
+    fno_ada_in_features : int | None, defaults to 4
+        if an adaptive mesh is used, number of channels of its positional embedding.
+        If None, adaptive mesh embedding is not used.
+    fno_ada_in_dim : int, defaults to 1
+        dimensions of above FNO adaptive mesh.
+    fno_preactivation : bool, defaults to False
+        whether to use ResNet-style preactivation.
+    fno_skip : str, defaults to 'linear'
+        type of skip connection to use.
+    fno_channel_mlp_skip : str, defaults to 'soft-gating'
+        type of skip connection to use in the FNO
+        'linear': conv layer
+        'soft-gating': weights the channels of the input
+        'identity': nn.Identity
+    fno_separable : bool, defaults to False
+        if True, use a depthwise separable spectral convolution.
+    fno_factorization : str {'tucker', 'tt', 'cp'} | None, defaults to None
+        tensor factorization to apply to the Fourier weights
+    fno_rank : float, defaults to 1.0
+        Rank of the tensor factorization of the Fourier weights.
+    fno_joint_factorization : bool, defaults to False
+        Whether all the Fourier layers should be parameterized by a single tensor (vs one per layer).
+    fno_fixed_rank_modes : bool, defaults to False
+        Modes to not factorize.
+    fno_implementation : str {'factorized', 'reconstructed'} | None, defaults to 'factorized'
+        If factorization is not None, forward mode to use::
+        * `reconstructed` : the full weight tensor is reconstructed from the factorization and used for the forward pass
+        * `factorized` : the input is directly contracted with the factors of the decomposition
+    fno_decomposition_kwargs : dict, defaults to dict()
+        Optional additional parameters to pass to the tensor decomposition.
+    fno_conv_module : nn.Module, defaults to SpectralConv
+        Spectral Convolution module to use.
+
+
+    References
+    ----------
+    .. [1] : Li, Z., Kovachki, N., Choy, C., Li, B., Kossaifi, J., Otta, S.,
+        Nabian, M., Stadler, M., Hundt, C., Azizzadenesheli, K., Anandkumar, A. (2023)
+        Geometry-Informed Neural Operator for Large-Scale 3D PDEs. NeurIPS 2023,
+        https://proceedings.neurips.cc/paper_files/paper/2023/hash/70518ea42831f02afc3a2828993935ad-Abstract-Conference.html
+    .. [2] : Lin, R. et al. Placeholder reference for Mollified Graph Neural Operators.
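+
+    Examples
+    --------
+    A minimal usage sketch; the shapes and hyperparameters below are
+    illustrative assumptions rather than values from a reference
+    configuration, and in a real run ``latent_queries`` should be a
+    regular grid on ``[0, 1]^3`` with radii tuned to its spacing:
+
+    >>> from mindspore import mint
+    >>> model = GINO(in_channels=3, out_channels=1,
+    ...              gno_use_open3d=False, gno_use_torch_scatter=False)
+    >>> input_geom = mint.rand(1, 100, 3)             # 100 input points in 3-d
+    >>> latent_queries = mint.rand(1, 16, 16, 16, 3)  # latent query "grid"
+    >>> output_queries = mint.rand(1, 50, 3)          # 50 output query points
+    >>> x = mint.rand(4, 100, 3)                      # batch of 4 input functions
+    >>> out = model(input_geom, latent_queries, output_queries, x=x)
+    >>> out.shape                                     # (batch, n_out, out_channels)
+    (4, 50, 1)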
+ """ + def __init__( + self, + in_channels, + out_channels, + latent_feature_channels=None, + projection_channel_ratio=4, + gno_coord_dim=3, + in_gno_radius=0.033, + out_gno_radius=0.033, + in_gno_transform_type='linear', + out_gno_transform_type='linear', + gno_weighting_function=None, + gno_weight_function_scale=1, + in_gno_pos_embed_type='transformer', + out_gno_pos_embed_type='transformer', + fno_in_channels=3, + fno_n_modes=(16, 16, 16), + fno_hidden_channels=64, + fno_lifting_channel_ratio=2, + fno_n_layers=4, + # Other GNO Params + gno_embed_channels=32, + gno_embed_max_positions=10000, + in_gno_channel_mlp_hidden_layers=[80, 80, 80], + out_gno_channel_mlp_hidden_layers=[512, 256], + gno_channel_mlp_non_linearity=F.gelu, + gno_use_open3d=True, + gno_use_torch_scatter=True, + out_gno_tanh=None, + # Other FNO Params + fno_resolution_scaling_factor=None, + fno_incremental_n_modes=None, + fno_block_precision='full', + fno_use_channel_mlp=True, + fno_channel_mlp_dropout=0, + fno_channel_mlp_expansion=0.5, + fno_non_linearity=F.gelu, + fno_stabilizer=None, + fno_norm=None, + fno_ada_in_features=4, + fno_ada_in_dim=1, + fno_preactivation=False, + fno_skip='linear', + fno_channel_mlp_skip='soft-gating', + fno_separable=False, + fno_factorization=None, + fno_rank=1.0, + fno_joint_factorization=False, + fno_fixed_rank_modes=False, + fno_implementation='factorized', + fno_decomposition_kwargs=dict(), + fno_conv_module=SpectralConv, + **kwargs + ): + + super().__init__() + self.in_channels = in_channels + self.out_channels = out_channels + self.latent_feature_channels = latent_feature_channels + self.gno_coord_dim = gno_coord_dim + self.fno_hidden_channels = fno_hidden_channels + + self.lifting_channels = fno_lifting_channel_ratio * fno_hidden_channels + + # If the input GNO performs a nonlinear kernel, the GNO's output + # features must be the same dimension as its input. + # otherwise the kernel's MLP will perform a lifting operation to + # lift the inputs to ``fno_in_channels`` channels + if in_gno_transform_type in ["nonlinear", "nonlinear_kernelonly"]: + in_gno_out_channels = self.in_channels + else: + in_gno_out_channels = fno_in_channels + + # The actual input channels to the FNO are computed here. 
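+        # fno_in_channels starts from the width produced by the input GNO and
+        # grows by latent_feature_channels when a latent feature map is
+        # concatenated onto the embedding before the FNO (see construct()).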
+ self.fno_in_channels = in_gno_out_channels + + if latent_feature_channels is not None: + self.fno_in_channels += latent_feature_channels + + if self.gno_coord_dim != 3 and gno_use_open3d: + print(f'Warning: GNO expects {self.gno_coord_dim}-d data but Open3d expects 3-d data') + gno_use_open3d = False + + self.in_coord_dim = len(fno_n_modes) + self.gno_out_coord_dim = len(fno_n_modes) # gno output and fno will use same dimensions + if self.in_coord_dim != self.gno_coord_dim: + print(f'Warning: FNO expects {self.in_coord_dim}-d data while input GNO expects {self.gno_coord_dim}-d data') + + self.in_coord_dim_forward_order = list(range(self.in_coord_dim)) + + # tensor indices starting at 2 to permute everything after channel and batch dims + self.in_coord_dim_reverse_order = [j + 2 for j in self.in_coord_dim_forward_order] + + self.fno_norm = fno_norm + if self.fno_norm == "ada_in": + if fno_ada_in_features is not None and out_gno_pos_embed_type is not None: + self.adain_pos_embed = SinusoidalEmbedding(in_channels=fno_ada_in_dim, + num_frequencies=fno_ada_in_features, + max_positions=10000, + embedding_type=out_gno_pos_embed_type) + self.ada_in_dim = self.adain_pos_embed.out_channels + else: + self.ada_in_dim = fno_ada_in_dim + self.adain_pos_embed = None + else: + self.adain_pos_embed = None + self.ada_in_dim = None + + self.in_gno_radius = in_gno_radius + self.out_gno_radius = out_gno_radius + + self.out_gno_tanh = out_gno_tanh + + ### input GNO + # input to the first GNO ChannelMLP: `x` pos encoding, + # `y` (integrand) pos encoding, potentially `f_y` + + self.gno_in = GNOBlock( + in_channels=in_channels, + out_channels=in_gno_out_channels, + coord_dim=self.gno_coord_dim, + pos_embedding_type=in_gno_pos_embed_type, + pos_embedding_channels=gno_embed_channels, + pos_embedding_max_positions=gno_embed_max_positions, + radius=in_gno_radius, + reduction='mean', + weighting_fn=None, + channel_mlp_layers=in_gno_channel_mlp_hidden_layers, + channel_mlp_non_linearity=gno_channel_mlp_non_linearity, + transform_type=in_gno_transform_type, + use_torch_scatter_reduce=gno_use_torch_scatter, + use_open3d_neighbor_search=gno_use_open3d, + ) + + ### Lifting layer before FNOBlocks + self.lifting = ChannelMLP(in_channels=self.fno_in_channels, + hidden_channels=self.lifting_channels, + out_channels=fno_hidden_channels, + n_layers=2) # CHANGED RECENTLY FOR THIS PAPER + + ### FNOBlocks in latent space + # input: `in_p` intermediate embeddings, + # possibly concatenated feature channels `latent_features` + self.fno_blocks = FNOBlocks( + n_modes=fno_n_modes, + hidden_channels=fno_hidden_channels, + in_channels=fno_hidden_channels, + out_channels=fno_hidden_channels, + positional_embedding=None, + n_layers=fno_n_layers, + resolution_scaling_factor=fno_resolution_scaling_factor, + incremental_n_modes=fno_incremental_n_modes, + fno_block_precision=fno_block_precision, + use_channel_mlp=fno_use_channel_mlp, + channel_mlp_expansion=fno_channel_mlp_expansion, + channel_mlp_dropout=fno_channel_mlp_dropout, + non_linearity=fno_non_linearity, + stabilizer=fno_stabilizer, + norm=fno_norm, + ada_in_features=self.ada_in_dim, + preactivation=fno_preactivation, + fno_skip=fno_skip, + channel_mlp_skip=fno_channel_mlp_skip, + separable=fno_separable, + factorization=fno_factorization, + rank=fno_rank, + joint_factorization=fno_joint_factorization, + fixed_rank_modes=fno_fixed_rank_modes, + implementation=fno_implementation, + decomposition_kwargs=fno_decomposition_kwargs, + domain_padding=None, + domain_padding_mode=None, + 
conv_module=fno_conv_module,
+            **kwargs
+        )
+
+        ### output GNO
+        if gno_weighting_function is not None:
+            # weighting functions are parameterized by the squared search radius
+            weight_fn = dispatch_weighting_fn(gno_weighting_function, sq_radius=out_gno_radius**2, scale=gno_weight_function_scale)
+        else:
+            weight_fn = None
+        self.gno_out = GNOBlock(
+            in_channels=fno_hidden_channels, # number of channels in f_y
+            out_channels=fno_hidden_channels,
+            coord_dim=self.gno_coord_dim,
+            radius=self.out_gno_radius,
+            reduction='sum',
+            weighting_fn=weight_fn,
+            pos_embedding_type=out_gno_pos_embed_type,
+            pos_embedding_channels=gno_embed_channels,
+            pos_embedding_max_positions=gno_embed_max_positions,
+            channel_mlp_layers=out_gno_channel_mlp_hidden_layers,
+            channel_mlp_non_linearity=gno_channel_mlp_non_linearity,
+            transform_type=out_gno_transform_type,
+            use_torch_scatter_reduce=gno_use_torch_scatter,
+            use_open3d_neighbor_search=gno_use_open3d,
+        )
+
+        projection_channels = projection_channel_ratio * fno_hidden_channels
+        self.projection = ChannelMLP(in_channels=fno_hidden_channels,
+                                     out_channels=self.out_channels,
+                                     hidden_channels=projection_channels,
+                                     n_layers=2,
+                                     n_dim=1,
+                                     non_linearity=fno_non_linearity)
+
+    # returns: (batch, fno_hidden_channels, n_1, n_2, ...)
+    def latent_embedding(self, in_p, ada_in=None):
+
+        # in_p : (batch, n_1, ..., n_k, in_channels + k)
+        # ada_in : (fno_ada_in_dim, )
+
+        # permute (b, n_1, ..., n_k, c) -> (b, c, n_1, ..., n_k)
+        in_p = in_p.permute(0, len(in_p.shape)-1, *list(range(1,len(in_p.shape)-1)))
+        # update Ada-IN embedding
+        if ada_in is not None:
+            if ada_in.ndim == 2:
+                ada_in = ada_in.squeeze(0)
+            if self.adain_pos_embed is not None:
+                ada_in_embed = self.adain_pos_embed(ada_in.unsqueeze(0)).squeeze(0)
+            else:
+                ada_in_embed = ada_in
+            if self.fno_norm == "ada_in":
+                self.fno_blocks.set_ada_in_embeddings(ada_in_embed)
+
+        # apply FNO blocks
+        in_p = self.lifting(in_p)
+
+        for idx in range(self.fno_blocks.n_layers):
+            in_p = self.fno_blocks(in_p, idx)
+
+        return in_p
+
+    def construct(self, input_geom, latent_queries, output_queries, x=None, latent_features=None, ada_in=None, **kwargs):
+        """The GINO's forward call:
+        Input GNO --> FNOBlocks --> output GNO + projection to output queries.
+
+        .. note ::
+            GINO currently supports batching **only in cases where the geometry of
+            inputs and outputs is shared across the entire batch**. Inputs can have a batch dim
+            in ``x`` and ``latent_features``, but it must be shared for both.
+
+        Parameters
+        ----------
+        input_geom : mindspore.Tensor
+            input domain coordinate mesh
+            shape (1, n_in, gno_coord_dim)
+        latent_queries : mindspore.Tensor
+            latent geometry on which to compute FNO latent embeddings
+            a grid on [0,1] x [0,1] x ....
+            shape (1, n_gridpts_1, .... n_gridpts_n, gno_coord_dim)
+        output_queries : mindspore.Tensor | dict[mindspore.Tensor]
+            points at which to query the final GNO layer to get output.
+
+            shape (1, n_out, gno_coord_dim) per tensor.
+
+            * if a tensor, the model will output a tensor.
+
+            * if a dict of tensors, the model will return a dict of outputs, so
+            that ``output[key]`` corresponds to the model queried at
+            ``output_queries[key]``.
+        x : mindspore.Tensor, optional
+            input function defined on the input domain `input_geom`,
+            shape (batch, n_in, in_channels), by default None
+        latent_features : mindspore.Tensor, optional
+            optional feature map to concatenate onto latent embedding
+            before being passed into the latent FNO, default None;
+            if `latent_feature_channels` is set, must be passed
+        ada_in : mindspore.Tensor, optional
+            adaptive scalar instance parameter, defaults to None
+
+        Returns
+        -------
+        out : mindspore.Tensor | dict[mindspore.Tensor]
+            Function over the output query coordinates
+            * tensor if ``output_queries`` is a tensor
+            * dict if ``output_queries`` is a dict
+        """
+
+        # Ensure input functions on the input geom and latent geom
+        # have compatible batch sizes
+        if x is None:
+            batch_size = 1
+        else:
+            batch_size = x.shape[0]
+
+        if latent_features is not None:
+            assert self.latent_feature_channels is not None,\
+                "if passing latent features, latent_feature_channels must be set."
+            assert latent_features.shape[-1] == self.latent_feature_channels
+
+            # batch, n_gridpts_1, .... n_gridpts_n, gno_coord_dim
+            assert latent_features.ndim == self.gno_coord_dim + 2,\
+                f"Latent features must be of shape (batch, n_gridpts_1, ...n_gridpts_n, gno_coord_dim), got {latent_features.shape}"
+            # latent features must have the same shape (except channels) as latent_queries
+            if latent_features.shape[0] != batch_size:
+                if latent_features.shape[0] == 1:
+                    # mint.tile gives the torch-style repeat along the batch
+                    # dim; MindSpore's Tensor.repeat is element-wise
+                    latent_features = mindspore.mint.tile(latent_features, (batch_size,) + (1,)*(latent_features.ndim-1))
+
+        input_geom = input_geom.squeeze(0)
+        latent_queries = latent_queries.squeeze(0)
+
+        # Pass through input GNOBlock
+        in_p = self.gno_in(y=input_geom,
+                           x=latent_queries.view((-1, latent_queries.shape[-1])),
+                           f_y=x)
+
+        grid_shape = latent_queries.shape[:-1] # disregard positional encoding dim
+
+        # shape (batch_size, grid1, ...gridn, -1)
+        in_p = in_p.view((batch_size, *grid_shape, -1))
+
+        if latent_features is not None:
+            in_p = mindspore.mint.cat((in_p, latent_features), dim=-1)
+        # apply the FNO in latent space
+        latent_embed = self.latent_embedding(in_p=in_p,
+                                             ada_in=ada_in)
+
+        # Integrate latent space to output queries
+        # latent_embed shape (b, c, n_1, n_2, ..., n_k)
+        batch_size = latent_embed.shape[0]
+        # permute to (b, n_1, n_2, ..., n_k, c)
+        # then reshape to (b, n_1 * n_2 * ... * n_k, fno_hidden_channels)
+        latent_embed = latent_embed.permute(0, *self.in_coord_dim_reverse_order, 1).reshape(batch_size, -1, self.fno_hidden_channels)
+
+        if self.out_gno_tanh in ['latent_embed', 'both']:
+            latent_embed = mindspore.mint.tanh(latent_embed)
+
+        # integrate over the latent space
+        # if output queries is a dict, query the output gno separately
+        # with each tensor of query points
+        if isinstance(output_queries, dict):
+            out = {}
+            for key, out_p in output_queries.items():
+                out_p = out_p.squeeze(0)
+
+                sub_output = self.gno_out(y=latent_queries.reshape((-1, latent_queries.shape[-1])),
+                                          x=out_p,
+                                          f_y=latent_embed,)
+                sub_output = sub_output.permute(0, 2, 1)
+
+                # Project pointwise to out channels
+                # (b, n_out, out_channels)
+                sub_output = self.projection(sub_output).permute(0, 2, 1)
+
+                out[key] = sub_output
+        else:
+            output_queries = output_queries.squeeze(0)
+
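+            # the output GNO treats the latent grid as a flat point cloud:
+            # each output query integrates over all latent positions that
+            # fall within out_gno_radius of it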
+            # latent_queries has shape (d_1, ..., d_n, n); flatten to (d_1 * ... * d_n, n)
+            out = self.gno_out(y=latent_queries.reshape((-1, latent_queries.shape[-1])),
+                               x=output_queries,
+                               f_y=latent_embed,)
+            out = out.permute(0, 2, 1)
+
+            # Project pointwise to out channels
+            # (b, n_out, out_channels)
+            out = self.projection(out).permute(0, 2, 1)
+
+        return out
\ No newline at end of file
-- 
Gitee