From c6518e754220093bb08709c51915f17e5e0bddf7 Mon Sep 17 00:00:00 2001 From: lry Date: Thu, 5 Sep 2024 10:18:24 +0800 Subject: [PATCH] fix mindearthAPI doc --- MindEarth/mindearth/data/dataset.py | 3 +++ MindEarth/mindearth/module/forecast.py | 2 ++ MindEarth/mindearth/module/pretrain.py | 12 +++++++++++- .../mindearth/data/mindearth.data.Era5Data.rst | 1 + .../module/mindearth.module.Trainer.rst | 15 ++++++++++++++- .../mindearth.module.WeatherForecast.rst | 18 +++++++++++++++++- 6 files changed, 48 insertions(+), 3 deletions(-) diff --git a/MindEarth/mindearth/data/dataset.py b/MindEarth/mindearth/data/dataset.py index 88ef3d108..761da0de7 100644 --- a/MindEarth/mindearth/data/dataset.py +++ b/MindEarth/mindearth/data/dataset.py @@ -132,6 +132,9 @@ class Era5Data(Data): data_params (dict): dataset-related configuration of the model. run_mode (str, optional): whether the dataset is used for training, evaluation or testing. Supports [“train”, “test”, “valid”]. Default: 'train'. + kno_patch (bool, optional): Indicates whether the data is already partitioned into patches. If True, the data + is assumed to be pre-processed and no further patching is performed. If False, the data will be processed + into patches as per the specified parameters. Default: False. Supported Platforms: ``Ascend`` ``GPU`` diff --git a/MindEarth/mindearth/module/forecast.py b/MindEarth/mindearth/module/forecast.py index 95014dc48..2af4ea427 100644 --- a/MindEarth/mindearth/module/forecast.py +++ b/MindEarth/mindearth/module/forecast.py @@ -214,6 +214,8 @@ class WeatherForecast: Args: dataset (mindspore.dataset): The dataset for eval, including inputs and labels. + generator_flag: "generator_flag" is used to pass a parameter to the "compute_total_rmse_acc" method. + A flag indicating whether to use a data generator or not. ''' data_length = len(dataset) // self.batch_size self.logger.info("================================Start Evaluation================================") diff --git a/MindEarth/mindearth/module/pretrain.py b/MindEarth/mindearth/module/pretrain.py index f48f845b6..80a4b8777 100644 --- a/MindEarth/mindearth/module/pretrain.py +++ b/MindEarth/mindearth/module/pretrain.py @@ -287,7 +287,17 @@ class Trainer: return solver def train(self): - """train.""" + """ + Train the model using the specified dataset and parameters. + + Initializes and configures the training process by setting up callbacks for monitoring + loss, time, and optionally predictions and checkpoints. The training is executed based + on the current training step, which may adjust the number of epochs or other parameters + for fine-tuning. + + Returns: + None, but modifies the model's state by training it. + """ callback_lst = [LossMonitor(), TimeMonitor()] if self.pred_cb: callback_lst.append(self.pred_cb) diff --git a/docs/api_python/mindearth/data/mindearth.data.Era5Data.rst b/docs/api_python/mindearth/data/mindearth.data.Era5Data.rst index 4df272e46..f5914b734 100644 --- a/docs/api_python/mindearth/data/mindearth.data.Era5Data.rst +++ b/docs/api_python/mindearth/data/mindearth.data.Era5Data.rst @@ -8,5 +8,6 @@ mindearth.data.Era5Data 参数: - **data_params** (dict) - 模型中的相关数据参数。 - **run_mode** (str, 可选) - 决定数据集用于训练、验证还是测试。支持 ``'train'``, ``'test'``, ``'valid'``。默认值: ``'train'``。 + - **kno_patch** (bool, 可选) - 决定数据集是否分割成小块。默认值:False。 diff --git a/docs/api_python/mindearth/module/mindearth.module.Trainer.rst b/docs/api_python/mindearth/module/mindearth.module.Trainer.rst index 71aca251d..a25a21ebb 100644 --- a/docs/api_python/mindearth/module/mindearth.module.Trainer.rst +++ b/docs/api_python/mindearth/module/mindearth.module.Trainer.rst @@ -56,4 +56,17 @@ mindearth.module.Trainer .. py:method:: mindearth.module.Trainer.train() - 执行模型训练。 \ No newline at end of file + 执行模型训练。 + + .. py:method:: mindearth.module.Trainer.get_data_generator() + + 生成用于训练和验证数据集的数据生成器。 + + 该函数根据指定的天气数据源创建数据生成器。 + 支持 'ERA5' 和 'DemSR' 数据源,对于不支持的数据源将引发错误。 + + 返回: + - 包含训练和验证数据生成器的元组。 + + 引发: + - NotImplementedError: 如果指定了不支持的数据源。 \ No newline at end of file diff --git a/docs/api_python/mindearth/module/mindearth.module.WeatherForecast.rst b/docs/api_python/mindearth/module/mindearth.module.WeatherForecast.rst index b50d42051..e79483efe 100644 --- a/docs/api_python/mindearth/module/mindearth.module.WeatherForecast.rst +++ b/docs/api_python/mindearth/module/mindearth.module.WeatherForecast.rst @@ -22,7 +22,8 @@ mindearth.module.WeatherForecast 参数: - **dataset** (mindspore.dataset) - 模型推理数据集,包括输入值和样本值。 - + - **generator_flag** (bool,可选) - 用于向 "compute_total_rmse_acc" 方法传递一个参数。 + 指示是否使用数据生成器。 .. py:method:: mindearth.module.WeatherForecast.forecast(inputs, labels=None) :staticmethod: @@ -31,4 +32,19 @@ mindearth.module.WeatherForecast 参数: - **inputs** (Tensor) - 模型的输入数据。 - **labels** (Tensor) - 样本真实数据。默认值: ``None``。 + .. py:method:: mindearth.module.WeatherForecast.compute_total_rmse_acc(dataset, generator_flag) + + 计算数据集的总体均方根误差(RMSE)和准确率。 + + 该函数遍历数据集,为每个批次计算RMSE和准确率, + 并累加结果以计算整个数据集的总体RMSE和准确率。 + + 参数: + - **dataset** (Dataset) - 用于计算指标的数据集对象。 + - **generator_flag** (bool) - 一个标志,指示是否使用数据生成器。 + + 返回: + - 包含数据集的总体准确率和RMSE的元组。 + 引发: + - NotImplementedError: 如果指定了不支持的数据源。 -- Gitee