From b5a26dcdb2ce7de98878ff7fe3cc5186c4f96c19 Mon Sep 17 00:00:00 2001 From: goto Date: Wed, 9 Apr 2025 16:24:23 +0800 Subject: [PATCH] add doc --- MindFlow/mindflow/core/optimizers.py | 29 ++++++++++++------- .../core/mindflow.core.AdaHessian.rst | 16 ++++++++++ docs/api_python/mindflow/mindflow.core.rst | 1 + docs/api_python_en/mindflow/mindflow.core.rst | 1 + 4 files changed, 36 insertions(+), 11 deletions(-) create mode 100644 docs/api_python/mindflow/core/mindflow.core.AdaHessian.rst diff --git a/MindFlow/mindflow/core/optimizers.py b/MindFlow/mindflow/core/optimizers.py index 58793c689..e98ccc12f 100644 --- a/MindFlow/mindflow/core/optimizers.py +++ b/MindFlow/mindflow/core/optimizers.py @@ -19,19 +19,25 @@ from mindspore import nn, ops class AdaHessian(nn.Adam): - """Implements Adahessian algorithm. - It has been proposed in `ADAHESSIAN: An Adaptive Second Order Optimizer for Machine Learning`. - See the (Torch implementation)[https://github.com/amirgholami/adahessian/blob/master/instruction/adahessian.py] - for reference. + r""" + The Adahessian optimizer. + It has been proposed in `ADAHESSIAN: An Adaptive Second Order Optimizer for Machine Learning + <https://arxiv.org/abs/2006.00719>`_ . + See the `Torch implementation + <https://github.com/amirgholami/adahessian/blob/master/instruction/adahessian.py>`_ for reference. The Hessian power here is fixed to 1, and the way of spatially averaging the Hessian traces follows the default behavior in the Torch implementation, that is - - for 1D: no spatial average - - for 2D: use the entire row as the spatial average - - for 3D (assume 1D Conv, can be customized): use the last dimension as spatial average - - for 4D (assume 2D Conv, can be customized): use the last 2 dimensions as spatial average - Arguments: - params (iterable): iterable of parameters to optimize - others: other arguments same to Adam + + - for 1D: no spatial average. + - for 2D: use the entire row as the spatial average. + - for 3D (assume 1D Conv, can be customized): use the last dimension as spatial average. 
+ - for 4D (assume 2D Conv, can be customized): use the last 2 dimensions as spatial average. + + For the arguments, see `mindspore.nn.Adam <https://www.mindspore.cn/docs/en/master/api_python/nn/mindspore.nn.Adam.html>`_ . + + Supported Platforms: + ``Ascend`` + Examples: >>> import numpy as np >>> import mindspore as ms @@ -48,6 +54,7 @@ class AdaHessian(nn.Adam): >>> print(optimizer.moment2[0].shape) (4, 2, 3, 3) """ + def gen_rand_vecs(self, grads): return [(2 * ops.randint(0, 2, p.shape) - 1).astype(ms.float32) for p in grads] diff --git a/docs/api_python/mindflow/core/mindflow.core.AdaHessian.rst b/docs/api_python/mindflow/core/mindflow.core.AdaHessian.rst new file mode 100644 index 000000000..6b75f82c0 --- /dev/null +++ b/docs/api_python/mindflow/core/mindflow.core.AdaHessian.rst @@ -0,0 +1,16 @@ +mindflow.core.AdaHessian +========================= + +.. py:class:: mindflow.core.AdaHessian(*args, **kwargs) + + 二阶优化器 AdaHessian,利用 Hessian 矩阵对角元信息进行二阶优化求解。 + 有关更多详细信息,请参考论文 `ADAHESSIAN: An Adaptive Second Order Optimizer for Machine Learning <https://arxiv.org/abs/2006.00719>`_ 。 + 相关 Torch 版本实现可参考 `Torch 版代码 <https://github.com/amirgholami/adahessian/blob/master/instruction/adahessian.py>`_ 。 + 此处 Hessian power 固定为 1,且对 Hessian 对角元做空间平均的方法与 Torch 实现的默认行为一致,描述如下: + + - 对于 1D 张量:不做空间平均; + - 对于 2D 张量:做行平均; + - 对于 3D 张量(假设为 1D 卷积):对最后一个维度做平均; + - 对于 4D 张量(假设为 2D 卷积):对最后两个维度做平均。 + + 参数说明详见 `mindspore.nn.Adam <https://www.mindspore.cn/docs/zh-CN/master/api_python/nn/mindspore.nn.Adam.html>`_ 。 \ No newline at end of file diff --git a/docs/api_python/mindflow/mindflow.core.rst b/docs/api_python/mindflow/mindflow.core.rst index 3bf67302f..0a50ae493 100644 --- a/docs/api_python/mindflow/mindflow.core.rst +++ b/docs/api_python/mindflow/mindflow.core.rst @@ -6,6 +6,7 @@ mindflow.core :nosignatures: :template: classtemplate.rst + mindflow.core.AdaHessian mindflow.core.get_multi_step_lr mindflow.core.get_poly_lr mindflow.core.get_warmup_cosine_annealing_lr diff --git a/docs/api_python_en/mindflow/mindflow.core.rst b/docs/api_python_en/mindflow/mindflow.core.rst index 3d4d207df..17ef48a00 100644 --- a/docs/api_python_en/mindflow/mindflow.core.rst +++ b/docs/api_python_en/mindflow/mindflow.core.rst @@ -6,6 +6,7 @@ mindflow.core :nosignatures: :template: 
classtemplate.rst + mindflow.core.AdaHessian mindflow.core.get_multi_step_lr mindflow.core.get_poly_lr mindflow.core.get_warmup_cosine_annealing_lr -- Gitee