diff --git a/MindEarth/mindearth/module/pretrain.py b/MindEarth/mindearth/module/pretrain.py index 6b486bbe6a2141da05787cf65f17725f594a3557..9503e401bf8e3128777a2e3ca9613a656535e7f8 100644 --- a/MindEarth/mindearth/module/pretrain.py +++ b/MindEarth/mindearth/module/pretrain.py @@ -52,7 +52,7 @@ class Trainer: model (mindspore.nn.Cell): network for training. loss_fn (mindspore.nn.Cell): loss function. logger (logging.RootLogger, optional): logger of the training process. Default: None. - weatherdata_type (str, optional): the dataset type. Default: 'Era5Data'. + weather_data_source (str, optional): the dataset type. Default: 'ERA5'. loss_scale (mindspore.amp.LossScaleManager, optional): the class of loss scale manager when using mixed precision. Default: mindspore.amp.DynamicLossScaleManager(). diff --git a/docs/api_python/mindearth/data/mindearth.data.Era5Data.rst b/docs/api_python/mindearth/data/mindearth.data.Era5Data.rst index 0f231faac7a3905fe827abfbfdf62cec96f908d9..4fdc6c63592dc3b360db76824b13c1b38a080c30 100644 --- a/docs/api_python/mindearth/data/mindearth.data.Era5Data.rst +++ b/docs/api_python/mindearth/data/mindearth.data.Era5Data.rst @@ -1,7 +1,7 @@ mindearth.data.Era5Data ========================= -.. py:class:: mindearth.data.Era5Data(data_params, run_mode='train') +.. py:class:: mindearth.data.Era5Data(data_params, run_mode='train', kno_patch=False) Era5Data类通过MindSpore框架处理ERA5数据集生成数据生成器。Era5Data类继承了Data类。 diff --git a/docs/api_python/mindearth/module/mindearth.module.Trainer.rst b/docs/api_python/mindearth/module/mindearth.module.Trainer.rst index 4c43f8e264978fa3d87ce0c7f46dcea2914eb150..6377c20d72bcde68abeb2000ab2f0c496377cbd9 100644 --- a/docs/api_python/mindearth/module/mindearth.module.Trainer.rst +++ b/docs/api_python/mindearth/module/mindearth.module.Trainer.rst @@ -13,7 +13,7 @@ mindearth.module.Trainer - **model** (mindspore.nn.Cell) - 用于训练的网络。 - **loss_fn** (mindspore.nn.Cell) - 损失函数。 - **logger** (logging.RootLogger, 可选) - 训练过程中的日志模块。默认值: ``None``。 - - **weatherdata_type** (str, 可选) - 数据的类型。默认值: ``Era5Data``。 + - **weather_data_source** (str, 可选) - 数据的类型。默认值: ``ERA5``。 - **loss_scale** (mindspore.amp.LossScaleManager, 可选) - 使用混合精度时,用于管理损失缩放系数的类。默认值: ``mindspore.amp.DynamicLossScaleManager()``。 @@ -21,16 +21,18 @@ mindearth.module.Trainer - **TypeError** - 如果 `model` 或 `loss_fn` 不是mindspore.nn.Cell。 - **NotImplementedError** - 如果 `get_callback` 的方法没有实现。 - .. py:method:: mindearth.module.Trainer.get_callback() - - 用于定义模型的回调类。用户必须自定义重写该方法。 + .. py:method:: mindearth.module.Trainer.get_data_generator() - .. py:method:: mindearth.module.Trainer.get_checkpoint() + 生成用于训练和验证数据集的数据生成器。 - 获得模型的checkpoint实例。 + 该函数根据指定的天气数据源创建数据生成器。 + 支持 'ERA5' 和 'DemSR' 数据源,对于不支持的数据源将引发错误。 返回: - Callback,模型的checkpoint实例. + - 包含训练和验证数据生成器的元组。 + + 异常: + - **NotImplementedError** - 如果指定了不支持的数据源。 .. py:method:: mindearth.module.Trainer.get_dataset() @@ -47,6 +49,17 @@ mindearth.module.Trainer 返回: Optimizer,模型的优化器。 + .. py:method:: mindearth.module.Trainer.get_checkpoint() + + 获得模型的checkpoint实例。 + + 返回: + Callback,模型的checkpoint实例. + + .. py:method:: mindearth.module.Trainer.get_callback() + + 用于定义模型的回调类。用户必须自定义重写该方法。 + .. py:method:: mindearth.module.Trainer.get_solver() 获得模型训练的求解器。 @@ -58,15 +71,3 @@ mindearth.module.Trainer 执行模型训练。 - .. py:method:: mindearth.module.Trainer.get_data_generator() - - 生成用于训练和验证数据集的数据生成器。 - - 该函数根据指定的天气数据源创建数据生成器。 - 支持 'ERA5' 和 'DemSR' 数据源,对于不支持的数据源将引发错误。 - - 返回: - - 包含训练和验证数据生成器的元组。 - - 异常: - - **NotImplementedError** - 如果指定了不支持的数据源。 diff --git a/docs/api_python/mindearth/module/mindearth.module.WeatherForecast.rst b/docs/api_python/mindearth/module/mindearth.module.WeatherForecast.rst index 3f0e026fa82f8830565e8a8e9ad90a6483b566de..0876c03e8ec98a28b6de7514f7dba77ac701a4d6 100644 --- a/docs/api_python/mindearth/module/mindearth.module.WeatherForecast.rst +++ b/docs/api_python/mindearth/module/mindearth.module.WeatherForecast.rst @@ -16,23 +16,22 @@ mindearth.module.WeatherForecast .. note:: 需要重写其中的成员函数 `forecast` 用于定义模型推理的前向过程。 - .. py:method:: mindearth.module.WeatherForecast.eval(dataset) - - 根据验证集数据或测试集数据执行模型推理。 - - 参数: - - **dataset** (mindspore.dataset) - 模型推理数据集,包括输入值和样本值。 - - **generator_flag** (bool, 可选) - 用于向 "compute_total_rmse_acc" 方法传递一个参数。指示是否使用数据生成器。 - .. py:method:: mindearth.module.WeatherForecast.forecast(inputs, labels=None) - :staticmethod: 模型的预测方法。 参数: - **inputs** (Tensor) - 模型的输入数据。 - **labels** (Tensor) - 样本真实数据。默认值: ``None``。 - + + .. py:method:: mindearth.module.WeatherForecast.eval(dataset, generator_flag=False) + + 根据验证集数据或测试集数据执行模型推理。 + + 参数: + - **dataset** (mindspore.dataset) - 模型推理数据集,包括输入值和样本值。 + - **generator_flag** (bool) - 用于向 "compute_total_rmse_acc" 方法传递一个参数。指示是否使用数据生成器。默认值: ``False``。 + .. py:method:: mindearth.module.WeatherForecast.compute_total_rmse_acc(dataset, generator_flag) 计算数据集的总体均方根误差(RMSE)和准确率。