diff --git a/MindEarth/mindearth/data/dataset.py b/MindEarth/mindearth/data/dataset.py index 88ef3d108b1e5f2750021694ba485eb8371fcf09..761da0de7ffe6745e55578532a06e8253799e624 100644 --- a/MindEarth/mindearth/data/dataset.py +++ b/MindEarth/mindearth/data/dataset.py @@ -132,6 +132,9 @@ class Era5Data(Data): data_params (dict): dataset-related configuration of the model. run_mode (str, optional): whether the dataset is used for training, evaluation or testing. Supports [“train”, “test”, “valid”]. Default: 'train'. + kno_patch (bool, optional): Indicates whether the data is already partitioned into patches. If True, the data + is assumed to be pre-processed and no further patching is performed. If False, the data will be processed + into patches as per the specified parameters. Default: False. Supported Platforms: ``Ascend`` ``GPU`` diff --git a/MindEarth/mindearth/module/forecast.py b/MindEarth/mindearth/module/forecast.py index 95014dc48209ac4da8c29d12eb97e1afbfb3c1d7..f5e4c6163f4b5fbb8aa2ebeb2d35c59de86492e8 100644 --- a/MindEarth/mindearth/module/forecast.py +++ b/MindEarth/mindearth/module/forecast.py @@ -214,6 +214,8 @@ class WeatherForecast: Args: dataset (mindspore.dataset): The dataset for eval, including inputs and labels. + generator_flag (bool): "generator_flag" is used to pass a parameter to the "compute_total_rmse_acc" method. + A flag indicating whether to use a data generator or not. ''' data_length = len(dataset) // self.batch_size self.logger.info("================================Start Evaluation================================") diff --git a/MindEarth/mindearth/module/pretrain.py b/MindEarth/mindearth/module/pretrain.py index f48f845b6c6450a0ad6a39f52f6d15aaf60f03d3..87dce7c8584923f4b5c8488b74d4e7bbbd57f6c3 100644 --- a/MindEarth/mindearth/module/pretrain.py +++ b/MindEarth/mindearth/module/pretrain.py @@ -287,7 +287,9 @@ class Trainer: return solver def train(self): - """train.""" + """ + Execute model training. + """ callback_lst = [LossMonitor(), TimeMonitor()] if self.pred_cb: callback_lst.append(self.pred_cb) diff --git a/docs/api_python/mindearth/data/mindearth.data.Era5Data.rst b/docs/api_python/mindearth/data/mindearth.data.Era5Data.rst index 4df272e466c8426bb8b04412e7845009d871feed..6de60d6f2b427fa9874210df77be3eccbf447241 100644 --- a/docs/api_python/mindearth/data/mindearth.data.Era5Data.rst +++ b/docs/api_python/mindearth/data/mindearth.data.Era5Data.rst @@ -8,5 +8,7 @@ mindearth.data.Era5Data 参数: - **data_params** (dict) - 模型中的相关数据参数。 - **run_mode** (str, 可选) - 决定数据集用于训练、验证还是测试。支持 ``'train'``, ``'test'``, ``'valid'``。默认值: ``'train'``。 + - **kno_patch** (bool, 可选) - 决定数据集是否分割成小块。如果为True,则假定数据已预处理,并且不会执行进一步的分割。如果为False, + 则将根据指定的参数将数据分割成小块。默认值: ``False``。 diff --git a/docs/api_python/mindearth/module/mindearth.module.Trainer.rst b/docs/api_python/mindearth/module/mindearth.module.Trainer.rst index 71aca251d891b90cdd4b85704ef3d72fa9cf2e89..4c43f8e264978fa3d87ce0c7f46dcea2914eb150 100644 --- a/docs/api_python/mindearth/module/mindearth.module.Trainer.rst +++ b/docs/api_python/mindearth/module/mindearth.module.Trainer.rst @@ -56,4 +56,17 @@ mindearth.module.Trainer .. py:method:: mindearth.module.Trainer.train() - 执行模型训练。 \ No newline at end of file + 执行模型训练。 + + .. py:method:: mindearth.module.Trainer.get_data_generator() + + 生成用于训练和验证数据集的数据生成器。 + + 该函数根据指定的天气数据源创建数据生成器。 + 支持 'ERA5' 和 'DemSR' 数据源,对于不支持的数据源将引发错误。 + + 返回: + - 包含训练和验证数据生成器的元组。 + + 异常: + - **NotImplementedError** - 如果指定了不支持的数据源。 diff --git a/docs/api_python/mindearth/module/mindearth.module.WeatherForecast.rst b/docs/api_python/mindearth/module/mindearth.module.WeatherForecast.rst index b50d420514e6c1507d39fd0dcb64a38e7168598b..9fe1df932dbc0bcf78984a68a87e5902bb9611bc 100644 --- a/docs/api_python/mindearth/module/mindearth.module.WeatherForecast.rst +++ b/docs/api_python/mindearth/module/mindearth.module.WeatherForecast.rst @@ -22,6 +22,7 @@ mindearth.module.WeatherForecast 参数: - **dataset** (mindspore.dataset) - 模型推理数据集,包括输入值和样本值。 + - **generator_flag** (bool,可选) - 用于向 "compute_total_rmse_acc" 方法传递一个参数。指示是否使用数据生成器。 .. py:method:: mindearth.module.WeatherForecast.forecast(inputs, labels=None) :staticmethod: @@ -31,4 +32,20 @@ mindearth.module.WeatherForecast 参数: - **inputs** (Tensor) - 模型的输入数据。 - **labels** (Tensor) - 样本真实数据。默认值: ``None``。 + + .. py:method:: mindearth.module.WeatherForecast.compute_total_rmse_acc(dataset, generator_flag) + 计算数据集的总体均方根误差(RMSE)和准确率。 + + 该函数遍历数据集,为每个批次计算RMSE和准确率, + 并累加结果以计算整个数据集的总体RMSE和准确率。 + + 参数: + - **dataset** (Dataset) - 用于计算指标的数据集对象。 + - **generator_flag** (bool) - 一个标志,指示是否使用数据生成器。 + + 返回: + - 包含数据集的总体准确率和RMSE的元组。 + + 引发: + - NotImplementedError: 如果指定了不支持的数据源。