diff --git a/MindFlow/mindflow/core/__init__.py b/MindFlow/mindflow/core/__init__.py
index c378b906954f4b7c49839cc106321d6088bd7521..df3b756b6352df1b63e9fd724eb2b9a15adece59 100644
--- a/MindFlow/mindflow/core/__init__.py
+++ b/MindFlow/mindflow/core/__init__.py
@@ -16,6 +16,7 @@
 from .lr_scheduler import get_poly_lr, get_multi_step_lr, get_warmup_cosine_annealing_lr
 from .losses import get_loss_metric, WaveletTransformLoss, MTLWeightedLoss, RelativeRMSELoss
 from .derivatives import batched_hessian, batched_jacobian
+from .optimizers import AdaHessian
 
 __all__ = ["get_poly_lr",
            "get_multi_step_lr",
@@ -26,6 +27,7 @@ __all__ = ["get_poly_lr",
            "RelativeRMSELoss",
            "batched_hessian",
            "batched_jacobian",
+           "AdaHessian",
            ]
 
 __all__.sort()
diff --git a/MindFlow/mindflow/core/optimizers.py b/MindFlow/mindflow/core/optimizers.py
new file mode 100644
index 0000000000000000000000000000000000000000..58793c689f29adbf14a2acc664e4dc8d80982ff5
--- /dev/null
+++ b/MindFlow/mindflow/core/optimizers.py
@@ -0,0 +1,117 @@
+# ============================================================================
+# Copyright 2025 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+''' Define customized optimizers and gradient accumulators '''
+import mindspore as ms
+from mindspore import nn, ops
+
+
+class AdaHessian(nn.Adam):
+    """Implements the AdaHessian algorithm.
+    It has been proposed in `ADAHESSIAN: An Adaptive Second Order Optimizer for Machine Learning`.
+    See the [Torch implementation](https://github.com/amirgholami/adahessian/blob/master/instruction/adahessian.py)
+    for reference.
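+
+    The diagonal of the Hessian is estimated with Hutchinson's method, diag(H) ~ E[z * (Hz)], where z is a random
+    Rademacher vector (entries of +1/-1) and the Hessian-vector product Hz is obtained by applying `ms.vjp` to the
+    gradient function, so the Hessian itself is never formed explicitly. Because |z| = 1 elementwise, the code can
+    work with |Hz| directly.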
+    The Hessian power here is fixed to 1, and the way of spatially averaging the Hessian traces follows the
+    default behavior of the Torch implementation, that is:
+    - for 1D: no spatial average
+    - for 2D: use the entire row as the spatial average
+    - for 3D (assume 1D Conv, can be customized): use the last dimension as the spatial average
+    - for 4D (assume 2D Conv, can be customized): use the last 2 dimensions as the spatial average
+    Arguments:
+        params (iterable): iterable of parameters to optimize
+        others: the remaining arguments are the same as for `nn.Adam`
+    Examples:
+        >>> import numpy as np
+        >>> import mindspore as ms
+        >>> from mindspore import ops, nn
+        >>> from mindflow import AdaHessian
+        >>> ms.set_context(device_target="Ascend", mode=ms.GRAPH_MODE)
+        >>> net = nn.Conv2d(in_channels=2, out_channels=4, kernel_size=3)
+        >>> def forward(a):
+        ...     return ops.mean(net(a)**2)**.5
+        >>> grad_fn = ms.grad(forward, grad_position=None, weights=net.trainable_params())
+        >>> optimizer = AdaHessian(net.trainable_params())
+        >>> inputs = ms.Tensor(np.reshape(range(100), [2, 2, 5, 5]), dtype=ms.float32)
+        >>> optimizer(grad_fn, inputs)
+        >>> print(optimizer.moment2[0].shape)
+        (4, 2, 3, 3)
+    """
+    def gen_rand_vecs(self, grads):
+        """Generate random Rademacher vectors (entries of +1/-1) with the same shapes as the gradients."""
+        return [(2 * ops.randint(0, 2, p.shape) - 1).astype(ms.float32) for p in grads]
+
+    def modify_moments(self, grad_fn, inputs):
+        """Incorporate the Hutchinson trace by pre-adding the scaled difference between its square and the
+        gradients' square to the second moment, so that the subsequent Adam update
+        `beta2 * m2 + (1 - beta2) * grad**2` effectively becomes `beta2 * m2 + (1 - beta2) * trace**2`.
+        """
+        # generate the function for the 2nd-order derivative;
+        # vjp_fn solves for the derivative with respect to both the inputs and the weights
+        grads, vjp_fn = ms.vjp(grad_fn, inputs, weights=self.parameters)
+
+        # generate random vectors
+        vs = self.gen_rand_vecs(grads)
+
+        # solve for the Hutchinson trace;
+        # when an operator does not support 2nd-order derivatives via vjp(), use `hvs = grads` instead
+        # to make the code run, although the output values will not be correct
+        _, hvs = vjp_fn(tuple(vs))
+
+        hutchinson_trace = []
+
+        for hv in hvs:
+            hv_abs = hv.abs()
+
+            if hv.ndim <= 1:
+                hutchinson_trace.append(hv_abs)
+            elif hv.ndim == 2:
+                hutchinson_trace.append(ops.mean(hv_abs, axis=[1], keep_dims=True))
+            elif hv.ndim == 3:
+                hutchinson_trace.append(ops.mean(hv_abs, axis=[2], keep_dims=True))
+            elif hv.ndim == 4:
+                hutchinson_trace.append(ops.mean(hv_abs, axis=[2, 3], keep_dims=True))
+            else:
+                raise RuntimeError(f'You need to write a customized spatial average to support this shape: {hv.shape}')
+
+        # modify moment2 in place; the (1 - beta2) / beta2 factor compensates for the decay applied later by Adam
+        for i in range(len(self.moment2)):
+            ops.assign(
+                self.moment2[i],
+                self.moment2[i] + (1. - self.beta2) * (
+                    hutchinson_trace[i] + grads[i]) * (hutchinson_trace[i] - grads[i]) / self.beta2)
+
+        return grads
+
+    def construct(self, grad_fn, inputs):
+        """Update the weights using the AdaHessian algorithm.
+        Args:
+            grad_fn (callable): the function that outputs the 1st-order gradients
+            inputs (Tensor): the inputs to the gradient function
+        """
+        gradients = self.modify_moments(grad_fn, inputs)
+
+        params = self._parameters
+        moment1 = self.moment1
+        moment2 = self.moment2
+        gradients = self.flatten_gradients(gradients)
+        gradients = self.decay_weight(gradients)
+        if not self.use_offload:
+            gradients = self.gradients_centralization(gradients)
+        gradients = self.scale_grad(gradients)
+        gradients = self._grad_sparse_indices_deduplicate(gradients)
+        lr = self.get_lr()
+        self.assignadd(self.global_step, self.global_step_increase_tensor)
+
+        self.beta1_power *= self.beta1
+        self.beta2_power *= self.beta2
+
+        return self._apply_adam(params, self.beta1_power, self.beta2_power, moment1, moment2, lr, gradients)
diff --git a/tests/st/mindflow/cell/test_optimizers.py b/tests/st/mindflow/cell/test_optimizers.py
new file mode 100644
index 0000000000000000000000000000000000000000..891695344f18dcfbfb785c4fc7d87b0da974479b
--- /dev/null
+++ b/tests/st/mindflow/cell/test_optimizers.py
@@ -0,0 +1,221 @@
+# ============================================================================
+# Copyright 2025 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""Optimizers Test Case"""
+import os
+import random
+import sys
+
+import pytest
+import numpy as np
+
+import mindspore as ms
+from mindspore import ops, set_seed, nn
+from mindspore import dtype as mstype
+from mindflow import UNet2D, AttentionBlock, AdaHessian
+from mindflow.cell.attention import Mlp
+from mindflow.cell.unet2d import Down
+
+PROJECT_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), "../../"))
+sys.path.append(PROJECT_ROOT)
+
+# pylint: disable=wrong-import-position
+
+from common.cell import FP32_RTOL
+
+# pylint: enable=wrong-import-position
+
+set_seed(0)
+np.random.seed(0)
+random.seed(0)
+
+
+class TestAdaHessianAccuracy(AdaHessian):
+    ''' Child class for testing the accuracy of the AdaHessian optimizer '''
+    def gen_rand_vecs(self, grads):
+        ''' Generate deterministic vectors so that the accuracy test is reproducible '''
+        return [ms.Tensor(np.arange(p.size).reshape(p.shape) - p.size // 2, dtype=ms.float32) for p in grads]
+
+
+class TestUNet2D(UNet2D):
+    ''' Child class for testing the optimization of UNet with AdaHessian '''
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+
+        class TestDown(Down):
+            def __init__(self, *args, **kwargs):
+                super().__init__(*args, **kwargs)
+                in_channels = args[0]
+                kernel_size = kwargs['kernel_size']
+                stride = kwargs['stride']
+                # replace the `maxpool` layer in the original UNet with `conv` to avoid the `vjp`
+                # problem with 2nd-order derivatives
+                self.maxpool = nn.Conv2d(in_channels, in_channels, kernel_size=kernel_size, stride=stride)
+
+        self.layers_down = nn.CellList()
+        for i in range(self.n_layers):
+            self.layers_down.append(TestDown(self.base_channels * 2**i, self.base_channels * 2 ** (i+1),
+                                             kernel_size=self.kernel_size, stride=self.stride,
+                                             activation=self.activation, enable_bn=self.enable_bn))
+
+
+class TestAttentionBlock(AttentionBlock):
+    ''' Child class for testing the optimization of Attention with AdaHessian '''
+    def __init__(self, in_channels, num_heads, drop_mode="dropout", dropout_rate=0.0, compute_dtype=mstype.float16):
+        super().__init__(
+            in_channels, num_heads, drop_mode=drop_mode, dropout_rate=dropout_rate, compute_dtype=compute_dtype)
+
+        class TestMlp(Mlp):
+            def __init__(self, *args, **kwargs):
+                super().__init__(*args, **kwargs)
+                self.act_fn = nn.ReLU()  # replace `gelu` with `relu` to avoid the `vjp` problem
+
+        self.ffn = TestMlp(in_channels=in_channels, dropout_rate=dropout_rate, compute_dtype=compute_dtype)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_arm_ascend910b_training
+@pytest.mark.env_onecard
+@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE])
+def test_adahessian_accuracy(mode):
+    """
+    Feature: AdaHessian forward accuracy test
+    Description: Test the accuracy of the AdaHessian optimizer in both GRAPH_MODE and PYNATIVE_MODE
+        with input data specified in the code below.
+        The expected output is compared to a reference output stored in
+        '/home/workspace/mindspore_dataset/mindscience/mindflow/optimizers/adahessian_output.npy'.
+    Expectation: The output should match the target data within the defined relative tolerance,
+        ensuring that the AdaHessian computation is accurate.
+ """ + ms.set_context(mode=mode) + + weight_init = ms.Tensor(np.reshape(range(72), [4, 2, 3, 3]), dtype=ms.float32) + bias_init = ms.Tensor(np.arange(4), dtype=ms.float32) + + net = nn.Conv2d( + in_channels=2, out_channels=4, kernel_size=3, has_bias=True, weight_init=weight_init, bias_init=bias_init) + + def forward(a): + return ops.mean(net(a)**2)**.5 + + grad_fn = ms.grad(forward, grad_position=None, weights=net.trainable_params()) + + optimizer = TestAdaHessianAccuracy( + net.trainable_params(), + learning_rate=0.1, beta1=0.9, beta2=0.999, eps=1e-8, weight_decay=0.) + + inputs = ms.Tensor(np.reshape(range(100), [2, 2, 5, 5]), dtype=ms.float32) + + for _ in range(4): + optimizer(grad_fn, inputs) + + outputs = net(inputs).numpy() + outputs_ref = np.load('/home/workspace/mindspore_dataset/mindscience/mindflow/optimizers/adahessian_output.npy') + relative_error = np.max(np.abs(outputs - outputs_ref)) / np.max(np.abs(outputs_ref)) + assert relative_error < FP32_RTOL, "The verification of adahessian accuracy is not successful." + +@pytest.mark.level0 +@pytest.mark.platform_arm_ascend910b_training +@pytest.mark.env_onecard +@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE]) +@pytest.mark.parametrize('model_option', ['unet']) +def test_adahessian_st(mode, model_option): + """ + Feature: AdaHessian ST test + Description: Test the function of the AdaHessian optimizer in both GRAPH_MODE and PYNATIVE_MODE + on the complex network such as UNet. The input is a Tensor specified in the code + and the output is the loss after 4 rounds of optimization. + Expectation: The output should be finite, ensuring the AdaHessian runs successfully on UNet. + """ + ms.set_context(mode=mode) + + # default test with Attention network + net = TestAttentionBlock(in_channels=256, num_heads=4) + inputs = ms.Tensor(np.sin(np.reshape(range(102400), [4, 100, 256])), dtype=ms.float32) + + # test with UNet network + if model_option.lower() == 'unet': + net = TestUNet2D( + in_channels=2, + out_channels=4, + base_channels=8, + n_layers=4, + kernel_size=2, + stride=2, + activation='relu', + data_format="NCHW", + enable_bn=True, + ) + inputs = ms.Tensor(np.sin(np.reshape(range(16384), [2, 2, 64, 64])), dtype=ms.float32) + + def forward(a): + return ops.mean(net(a)**2)**.5 + + grad_fn = ms.grad(forward, grad_position=None, weights=net.trainable_params()) + + optimizer = AdaHessian( + net.trainable_params(), + learning_rate=0.1, beta1=0.9, beta2=0.999, eps=1e-8, weight_decay=0.) + + for _ in range(4): + loss = forward(inputs) + optimizer(grad_fn, inputs) + + assert ops.isfinite(loss) + +@pytest.mark.level0 +@pytest.mark.platform_arm_ascend910b_training +@pytest.mark.env_onecard +def test_adahessian_compare(): + """ + Feature: AdaHessian compare with Adam + Description: Compare the algorithm results of the AdaHessian optimizer with Adam. + The code runs in PYNATIVE_MODE and the network under comparison is AttentionBlock. + The optimization runs 100 rounds to demonstrate an essential loss decrease. + Expectation: The loss of AdaHessian outperforms Adam by 20% under the same configuration on an Attention network. 
+ """ + ms.set_context(mode=ms.PYNATIVE_MODE) + + def get_loss(optimizer_option): + ''' compare Adam and AdaHessian ''' + net = TestAttentionBlock(in_channels=256, num_heads=4) + inputs = ms.Tensor(np.sin(np.reshape(range(102400), [4, 100, 256])), dtype=ms.float32) + + def forward(a): + return ops.mean(net(a)**2)**.5 + + grad_fn = ms.grad(forward, grad_position=None, weights=net.trainable_params()) + + if optimizer_option.lower() == 'adam': + optimizer = nn.Adam( + net.trainable_params(), + learning_rate=0.01, beta1=0.9, beta2=0.999, eps=1e-8, weight_decay=0.) + else: + optimizer = AdaHessian( + net.trainable_params(), + learning_rate=0.01, beta1=0.9, beta2=0.999, eps=1e-8, weight_decay=0.) + + for _ in range(100): + loss = forward(inputs) + if optimizer_option.lower() == 'adam': + optimizer(grad_fn(inputs)) + else: + optimizer(grad_fn, inputs) + + return loss + + loss_adam = get_loss('adam') + loss_adahessian = get_loss('adahessian') + + assert loss_adam * 0.8 > loss_adahessian, (loss_adam, loss_adahessian)