diff --git a/MindEarth/applications/medium-range/graphcast/graphcast_tp.ipynb b/MindEarth/applications/medium-range/graphcast/graphcast_tp.ipynb index e98b680cce40b2116b34158e93de38e45e240242..f047ef947bcd3e3798a7ef06c815fc356c77d46a 100644 --- a/MindEarth/applications/medium-range/graphcast/graphcast_tp.ipynb +++ b/MindEarth/applications/medium-range/graphcast/graphcast_tp.ipynb @@ -43,8 +43,7 @@ "import numpy as np\n", "import matplotlib.pyplot as plt\n", "\n", - "from mindspore import set_seed\n", - "from mindspore import context, ops\n" + "from mindspore import set_seed, context, ops, nn" ] }, { @@ -94,7 +93,7 @@ "metadata": {}, "outputs": [], "source": [ - "config = load_yaml_config(\"./GraphCastTp.yaml\")\n", + "config = load_yaml_config(\"./configs/GraphCastTp.yaml\")\n", "context.set_context(mode=context.GRAPH_MODE, device_target=\"Ascend\", device_id=5)" ] }, @@ -194,7 +193,8 @@ "sj_std, wj, ai = get_coe(config)\n", "data_params = config.get('data')\n", "loss_fn = LossNet(ai, wj, sj_std, data_params.get('feature_dims'), data_params['tp'])\n", - "loss_cell = CustomWithLossCell(backbone=model, loss_fn=loss_fn, data_params=data_params)" + "loss_cell = CustomWithLossCell(backbone=model, loss_fn=loss_fn, data_params=data_params)\n", + "loss_scale = nn.DynamicLossScaleUpdateCell(loss_scale_value=2 ** 12, scale_factor=2, scale_window=1000)" ] }, { @@ -203,7 +203,7 @@ "metadata": {}, "outputs": [], "source": [ - "trainer = GraphCastTrainerTp(config, model, loss_cell, logger)\n", + "trainer = GraphCastTrainerTp(config, model, loss_cell, logger, loss_scale)\n", "trainer.train()" ] }, diff --git a/MindEarth/applications/medium-range/graphcast/graphcast_tp_CN.ipynb b/MindEarth/applications/medium-range/graphcast/graphcast_tp_CN.ipynb index b0d9556a0b7c222db30b05c3709f161d579657af..97ae60ba6d510266f63724ca41f1b9475d6a2bda 100644 --- a/MindEarth/applications/medium-range/graphcast/graphcast_tp_CN.ipynb +++ b/MindEarth/applications/medium-range/graphcast/graphcast_tp_CN.ipynb @@ -42,8 +42,7 @@ "import numpy as np\n", "import matplotlib.pyplot as plt\n", "\n", - "from mindspore import set_seed\n", - "from mindspore import context, ops\n" + "from mindspore import set_seed, context, ops, nn\n" ] }, { @@ -93,7 +92,7 @@ "metadata": {}, "outputs": [], "source": [ - "config = load_yaml_config(\"./GraphCastTp.yaml\")\n", + "config = load_yaml_config(\"./configs/GraphCastTp.yaml\")\n", "context.set_context(mode=context.GRAPH_MODE, device_target=\"Ascend\", device_id=5)" ] }, @@ -190,7 +189,8 @@ "sj_std, wj, ai = get_coe(config)\n", "data_params = config.get('data')\n", "loss_fn = LossNet(ai, wj, sj_std, data_params.get('feature_dims'), data_params['tp'])\n", - "loss_cell = CustomWithLossCell(backbone=model, loss_fn=loss_fn, data_params=data_params)" + "loss_cell = CustomWithLossCell(backbone=model, loss_fn=loss_fn, data_params=data_params)\n", + "loss_scale = nn.DynamicLossScaleUpdateCell(loss_scale_value=2 ** 12, scale_factor=2, scale_window=1000)" ] }, { @@ -199,7 +199,7 @@ "metadata": {}, "outputs": [], "source": [ - "trainer = GraphCastTrainerTp(config, model, loss_cell, logger)\n", + "trainer = GraphCastTrainerTp(config, model, loss_cell, logger, loss_scale)\n", "trainer.train()" ] }, diff --git a/MindEarth/applications/medium-range/koopman_vit/src/callback.py b/MindEarth/applications/medium-range/koopman_vit/src/callback.py index c6a865213c4507533ef40971c156aadb4c370eb8..f8802248e95e06d5de6ed611f5435e96bc72d9e0 100644 --- a/MindEarth/applications/medium-range/koopman_vit/src/callback.py +++ b/MindEarth/applications/medium-range/koopman_vit/src/callback.py @@ -123,7 +123,7 @@ class InferenceModule(WeatherForecast): """ def __init__(self, model, config, logger): - super(InferenceModule, self).__init__() + super().__init__(model, config, logger) self.model = model self.config = config self.logger = logger diff --git a/MindEarth/applications/medium-range/skno/main.py b/MindEarth/applications/medium-range/skno/main.py index c0a2cc93000f321c33a1de409b34e5252beb7965..8ab61561c38584e423ba1a3712052f40e6b355de 100644 --- a/MindEarth/applications/medium-range/skno/main.py +++ b/MindEarth/applications/medium-range/skno/main.py @@ -37,7 +37,7 @@ random.seed(0) def get_args(): """Get user specified parameters.""" parser = argparse.ArgumentParser(description='SKNO') - parser.add_argument('--config_file_path', type=str, default='/configs/skno.yaml') + parser.add_argument('--config_file_path', type=str, default='./configs/skno.yaml') parser.add_argument('--device_target', '-d', type=str, choices=["Ascend", "GPU"], default="Ascend") parser.add_argument("--mode", type=str, default="GRAPH", choices=["GRAPH", "PYNATIVE"], diff --git a/MindEarth/applications/medium-range/skno/src/callback.py b/MindEarth/applications/medium-range/skno/src/callback.py index 8aa7ce2784b96d302ea96137bf6ad24744c23f5a..d7a20d0f53b514c78367a3e8a5b45e7869112b05 100644 --- a/MindEarth/applications/medium-range/skno/src/callback.py +++ b/MindEarth/applications/medium-range/skno/src/callback.py @@ -132,7 +132,7 @@ class InferenceModule(WeatherForecast): """ def __init__(self, model, config, logger): - super(InferenceModule, self).__init__() + super().__init__(model, config, logger) self.model = model self.config = config self.logger = logger diff --git a/MindEarth/mindearth/cell/graphcast/graphcast.py b/MindEarth/mindearth/cell/graphcast/graphcast.py index d93a30c7da0eed244b6a4454248dfe48af7a1788..56acace187627d9db3f5edeaedddbd9e8aac11ae 100644 --- a/MindEarth/mindearth/cell/graphcast/graphcast.py +++ b/MindEarth/mindearth/cell/graphcast/graphcast.py @@ -25,26 +25,6 @@ set_seed(0) np.random.seed(0) -class GraphCastSiLU(nn.Cell): - r""" - A self-defined SwiGlu. - - Inputs: - - **x** (Tensor) - Tensor. - - Outputs: - Tensor. x = x * sigmod(x). - """ - def __init__(self): - super().__init__() - self.sigmoid = nn.Sigmoid() - self.mul = ops.Mul() - - def construct(self, x): - """GraphCastSiLU forward function.""" - return self.mul(x, self.sigmoid(x)) - - class MLPNet(nn.Cell): """ The MLPNet Network. Applies a series of fully connected layers to the incoming data among which hidden layers have @@ -59,7 +39,6 @@ class MLPNet(nn.Cell): Inputs: - **input** (Tensor) - Tensor of shape :math:`(*, dims[0]) - Outputs: - **output** (Tensor) - Tensor of shape :math:`(*, dims[-1]) @@ -76,21 +55,26 @@ class MLPNet(nn.Cell): (2, 8) """ + def __init__(self, in_channels, out_channels, - latent_dims): + latent_dims, + has_layernorm=True): super(MLPNet, self).__init__() - self.dense_in = nn.Dense(in_channels, - latent_dims, - has_bias=False, - activation=None) - self.silu = GraphCastSiLU() - self.dense_out = nn.Dense(latent_dims, - out_channels, - has_bias=False, - activation=None) - self.layer_norm = nn.LayerNorm([out_channels]) + cell_list = [nn.Dense(in_channels, + latent_dims, + has_bias=False, + activation=None), + nn.SiLU(), + nn.Dense(latent_dims, + out_channels, + has_bias=False, + activation=None), + ] + if has_layernorm: + cell_list.append(nn.LayerNorm([out_channels])) + self.network = nn.SequentialCell(cell_list) def construct(self, x: Tensor): '''MLPNet forward function @@ -98,11 +82,7 @@ class MLPNet(nn.Cell): Args: x (Tensor): Input Tensor. ''' - x = self.dense_in(x) - x = self.silu(x) - x = self.dense_out(x) - x = self.layer_norm(x) - return x + return self.network(x) class Embedder(nn.Cell): @@ -351,7 +331,8 @@ class Decoder(nn.Cell): dst_idx) self.node_fn = MLPNet(in_channels=node_in_channels, out_channels=node_final_dims, - latent_dims=latent_dims) + latent_dims=latent_dims, + has_layernorm=False) def construct(self, m2g_edge_feats, mesh_node_feats, grid_node_feats): '''Decoder forward function''' diff --git a/MindEarth/mindearth/module/forecast.py b/MindEarth/mindearth/module/forecast.py index 5c507ac0305e766e9755eabb07244f5502157050..4744ef650dfc982111fa2f063051406bd172c79c 100644 --- a/MindEarth/mindearth/module/forecast.py +++ b/MindEarth/mindearth/module/forecast.py @@ -161,7 +161,10 @@ class WeatherForecast: self.feature_dims = config['data'].get('feature_dims', 69) self.total_std = self._get_total_sample_description(config, "std") self.total_mean = self._get_total_sample_description(config, "mean") - self.climate_mean = self._get_history_climate_mean(config, self.w_size, self.adjust_size) + if config['model']['name'] == "GraphCastTp": + self.climate_mean = self._get_history_climate_mean(config) + else: + self.climate_mean = self._get_history_climate_mean(config, self.w_size, self.adjust_size) self.run_mode = config['train'].get("run_mode", 'train') if self.run_mode == 'train': self.t_out = config['data'].get("t_out_valid", 20) @@ -186,7 +189,7 @@ class WeatherForecast: return total_sample_info @staticmethod - def _get_history_climate_mean(config, w_size, adjust_size=False): + def _get_history_climate_mean(config, w_size=None, adjust_size=False): """get history climate mean.""" data_params = config.get('data') climate_mean = np.load(os.path.join(data_params.get("root_dir"), "statistic",