diff --git a/MindEarth/applications/sea/LeadFormer/LeadFormer.ipynb b/MindEarth/applications/sea/LeadFormer/LeadFormer.ipynb new file mode 100644 index 0000000000000000000000000000000000000000..0c3148f2e6b4c9baf95b590de471de7bc181bdcc --- /dev/null +++ b/MindEarth/applications/sea/LeadFormer/LeadFormer.ipynb @@ -0,0 +1,554 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "aa1f8be6-a0a1-4842-a79f-7eb8cf34c4f9", + "metadata": {}, + "source": [ + "# LeadFormer: 北极海冰高分辨率智能预报¶\n", + "\n", + "## 概述\n", + "\n", + "冰间水道是海冰在海浪、风力和洋流作用下形成的线状断裂带,其形态特征能够反映海洋与大气之间物质能量交换的强度,影响着水道表面的湍流热通量。因此,冰间水道的形态及空间分布的准确刻画对研究北极的海冰变化和预测航道通航具有重要意义。冰间水道的形态特征包括长度、宽度和倾角等。冰间水道宽度在一定程度上决定了大气和海洋水热交换的强度,水道倾角反应且影响海冰动力学特征,水道总长度可以作为衡量冰间水道尺度变异及季节和年际变化的指标。高分辨率海冰冰间水道预测模型是当前应对全球气候变暖背景下北极海冰快速变化的关键技术工具。针对海冰变化机理的复杂性和海冰预报的不确定性,***LeadFormer***以北极高分辨率数值模式数据和基于transformer的人工智能模型为支撑,实现北极冰间水道的智能预报,区域覆盖泛北极,分辨率达到2km的高分辨率冰情预报体系。\n", + "\n", + "![LeadFormer](images/model.png)\n", + "\n", + "该模型采用编码器-解码器框架,编码阶段通过重叠块嵌入和四级Transformer块实现特征压缩与深化;解码阶段通过MLP和上采样操作逐步重建空间维度;核心创新在于融合Transformer的全局建模能力与CNN的局部感知特性,适用于高精度图像处理任务。\n", + "\n", + "本模型数据集暂不开源,仅开源代码。" + ] + }, + { + "cell_type": "markdown", + "id": "ad6f6a1a-8fc9-4202-b6c4-5521f0b76834", + "metadata": {}, + "source": [ + "## 概述\n", + "\n", + "MindEarth求解该问题的具体流程如下:\n", + "\n", + "1.模型构建\n", + "\n", + "2.模型训练\n", + "\n", + "3.模型评估与可视化\n", + "\n", + "本模型数据集暂不开源,仅开源代码。" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "76482de7-2f70-4176-8cc2-ac411d95fce5", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "[WARNING] ME(1012809:281472904290336,MainProcess):2025-08-07-15:36:02.292.000 [mindspore/run_check/_check_version.py:402] Can not find the tbe operator implementation(need by mindspore-ascend). Please check whether the Environment Variable PYTHONPATH is set. For details, refer to the installation guidelines: https://www.mindspore.cn/install\n" + ] + } + ], + "source": [ + "import mindspore as ms\n", + "from mindspore import set_seed, context" + ] + }, + { + "cell_type": "markdown", + "id": "0a61d243-a38d-406d-b24b-f310d2e083a4", + "metadata": {}, + "source": [ + "下述src可以在[LeadFormer/src](./src)下载。" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "92350b28-3bf3-4cc3-a120-ad3ba0cde982", + "metadata": {}, + "outputs": [], + "source": [ + "from mindearth.utils import load_yaml_config\n", + "\n", + "from src.solver import Trainer\n", + "from src.forecast import Tester" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "758ef133-7955-438b-acfb-dcad8858ab45", + "metadata": {}, + "outputs": [], + "source": [ + "set_seed(0)" + ] + }, + { + "cell_type": "markdown", + "id": "d54b33c2-1572-48c0-8300-75a3d96c549c", + "metadata": {}, + "source": [ + "可以在[配置文件](./configs/2km_ice_config.yaml)中配置模型、数据和优化器等参数。" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "6857d73b-358d-4e85-818e-7fa5f089a84f", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "[WARNING] ME(1012809:281472904290336,MainProcess):2025-08-07-15:36:09.777.000 [mindspore/run_check/_check_version.py:402] Can not find the tbe operator implementation(need by mindspore-ascend). Please check whether the Environment Variable PYTHONPATH is set. 
For details, refer to the installation guidelines: https://www.mindspore.cn/install\n" + ] + } + ], + "source": [ + "context.set_context(mode=ms.PYNATIVE_MODE)\n", + "ms.set_device(device_target=\"Ascend\", device_id=4)\n", + "config = load_yaml_config(\"./configs/2km_ice_config.yaml\")" + ] + }, + { + "cell_type": "markdown", + "id": "eb526cde-3ea4-4050-8fe1-45f6b2d14eca", + "metadata": {}, + "source": [ + "## 模型训练\n", + "\n", + "在本教程中,我们使用Trainer对模型进行训练。" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "668a976d-ffdd-4807-b90a-e1baefcb783e", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "[WARNING] ME(3419713:281473115082784,MainProcess):2025-08-07-10:56:00.474.000 [mindspore/dataset/core/config.py:685] The shared memory is on, multiprocessing performance will be improved. Note: the required shared memory can't exceeds 80% of the available shared memory.\n", + "[WARNING] ME(3419713:281473115082784,MainProcess):2025-08-07-10:56:01.280.00 [mindspore/run_check/_check_version.py:305] The version 7.6 used for compiling the custom operator does not match Ascend AI software package version 7.5 in the current environment.\n", + "[WARNING] ME(3419713:281473115082784,MainProcess):2025-08-07-10:56:01.330.00 [mindspore/train/model.py:1419] For StepLossTimeMonitor callback, {'step_begin', 'epoch_end', 'step_end', 'epoch_begin'} methods may not be supported in later version, Use methods prefixed with 'on_train' or 'on_eval' instead when using customized callbacks.\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "============== Starting Training ==============\n", + "==============================================================\n", + ".............step: 1, loss is 0.2585278, fps is 0.007067410267447278, lr is 1e-04\n", + "step: 2, loss is 1.3312725, fps is 0.18855858711971155, lr is 1e-04\n", + "step: 3, loss is 0.39229235, fps is 0.27716471662093956, lr is 1e-04\n", + "step: 4, loss is 0.34179375, fps is 0.3300149329982123, lr is 1e-04\n", + "step: 5, loss is 0.4487785, fps is 0.3321377820521602, lr is 1e-04\n", + "step: 6, loss is 0.17796662, fps is 0.28678993865574326, lr is 1e-04\n", + "step: 7, loss is 0.17500843, fps is 0.2804825260482442, lr is 1e-04\n", + "step: 8, loss is 0.1605135, fps is 0.32581367309958326, lr is 1e-04\n", + "step: 9, loss is 0.1647645, fps is 0.27461646180829363, lr is 1e-04\n", + "epoch: 1, avg loss:0.3834, total cost: 176.468 s, per step fps:0.051\n", + "step: 1, loss is 0.19449174, fps is 0.4835091969602687, lr is 9.9726094e-05\n", + "step: 2, loss is 0.07440963, fps is 0.4702539411465023, lr is 9.9726094e-05\n", + "step: 3, loss is 0.27065408, fps is 0.46883899778116844, lr is 9.9726094e-05\n", + "step: 4, loss is 0.09641305, fps is 0.47306889037293376, lr is 9.9726094e-05\n", + "step: 5, loss is 0.15553927, fps is 0.4792626517483649, lr is 9.9726094e-05\n", + "step: 6, loss is 0.13678743, fps is 0.47240835108772194, lr is 9.9726094e-05\n", + "step: 7, loss is 0.18045919, fps is 0.4727445974472515, lr is 9.9726094e-05\n", + "step: 8, loss is 0.121810436, fps is 0.4675605187249209, lr is 9.9726094e-05\n", + "step: 9, loss is 0.098566085, fps is 0.46855810855662056, lr is 9.9726094e-05\n", + "epoch: 2, avg loss:0.1477, total cost: 20.548 s, per step fps:0.438\n", + "step: 1, loss is 0.19244102, fps is 0.46974691839125243, lr is 9.890738e-05\n", + "step: 2, loss is 0.15396605, fps is 0.47269851165984045, lr is 9.890738e-05\n", + "step: 3, loss is 0.21747951, fps is 
0.47871531890983626, lr is 9.890738e-05\n", + "step: 4, loss is 0.22973962, fps is 0.4678759613964494, lr is 9.890738e-05\n", + "step: 5, loss is 0.14149912, fps is 0.4686135474245217, lr is 9.890738e-05\n", + "step: 6, loss is 0.13367249, fps is 0.4676195275534368, lr is 9.890738e-05\n", + "step: 7, loss is 0.10203815, fps is 0.46257949112665014, lr is 9.890738e-05\n", + "step: 8, loss is 0.13874854, fps is 0.468793670291153, lr is 9.890738e-05\n", + "step: 9, loss is 0.181213, fps is 0.4711711134851441, lr is 9.890738e-05\n", + "epoch: 3, avg loss:0.1656, total cost: 20.472 s, per step fps:0.440\n", + "step: 1, loss is 0.23165968, fps is 0.4659994491516903, lr is 9.7552824e-05\n", + "step: 2, loss is 0.072046176, fps is 0.4671766475833158, lr is 9.7552824e-05\n", + "step: 3, loss is 0.23098187, fps is 0.47607396613513847, lr is 9.7552824e-05\n", + "step: 4, loss is 0.17709546, fps is 0.4714999288417208, lr is 9.7552824e-05\n", + "step: 5, loss is 0.13153948, fps is 0.47837914901673634, lr is 9.7552824e-05\n", + "step: 6, loss is 0.10806603, fps is 0.46920408551603504, lr is 9.7552824e-05\n", + "step: 7, loss is 0.2324798, fps is 0.4703840462320135, lr is 9.7552824e-05\n", + "step: 8, loss is 0.20180652, fps is 0.47566892757473156, lr is 9.7552824e-05\n", + "step: 9, loss is 0.17364886, fps is 0.4421808711280614, lr is 9.7552824e-05\n", + "......\n", + "epoch: 28, avg loss:0.1316, total cost: 20.141 s, per step fps:0.447\n", + "step: 1, loss is 0.096311785, fps is 0.48387031277398873, lr is 1.09262e-06\n", + "step: 2, loss is 0.1910012, fps is 0.4785059284686627, lr is 1.09262e-06\n", + "step: 3, loss is 0.09300855, fps is 0.4757382027972377, lr is 1.09262e-06\n", + "step: 4, loss is 0.06182714, fps is 0.4843480540296477, lr is 1.09262e-06\n", + "step: 5, loss is 0.1028601, fps is 0.4660128072014537, lr is 1.09262e-06\n", + "step: 6, loss is 0.0648559, fps is 0.4695110811807986, lr is 1.09262e-06\n", + "step: 7, loss is 0.2138748, fps is 0.4715329522445039, lr is 1.09262e-06\n", + "step: 8, loss is 0.18132131, fps is 0.4756527446825787, lr is 1.09262e-06\n", + "step: 9, loss is 0.070166126, fps is 0.4750834309048936, lr is 1.09262e-06\n", + "epoch: 29, avg loss:0.1195, total cost: 20.285 s, per step fps:0.444\n", + "step: 1, loss is 0.16265991, fps is 0.4689126933090539, lr is 2.7390524e-07\n", + "step: 2, loss is 0.06894889, fps is 0.46606837040793697, lr is 2.7390524e-07\n", + "step: 3, loss is 0.13682221, fps is 0.468522307987637, lr is 2.7390524e-07\n", + "step: 4, loss is 0.14170526, fps is 0.47180951256070847, lr is 2.7390524e-07\n", + "step: 5, loss is 0.07896267, fps is 0.47122563713087756, lr is 2.7390524e-07\n", + "step: 6, loss is 0.106654026, fps is 0.4698092168453713, lr is 2.7390524e-07\n", + "step: 7, loss is 0.1188283, fps is 0.4782360775717037, lr is 2.7390524e-07\n", + "step: 8, loss is 0.146635, fps is 0.46966906855383656, lr is 2.7390524e-07\n", + "step: 9, loss is 0.20011184, fps is 0.46569113259110906, lr is 2.7390524e-07\n", + "epoch: 30, avg loss:0.1290, total cost: 20.339 s, per step fps:0.442\n", + "============== End Training ==============\n" + ] + } + ], + "source": [ + "epoch_size = config[\"train\"].get(\"epochs\", 300)\n", + "trainer = Trainer(config, epochs=epoch_size)\n", + "trainer.train()" + ] + }, + { + "cell_type": "markdown", + "id": "19ade00d-8f02-4d2e-815f-05af82875883", + "metadata": {}, + "source": [ + "## 模型评估\n", + "\n", + "完成训练后,我们使用第30个epoch的权重进行推理。下述展示了预测值与实际值之间的误差和各项指标。" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": 
"5c8d5f01-7dcc-46ff-933d-0f3ce82775ad", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "[WARNING] ME(1012809:281472904290336,MainProcess):2025-08-07-15:36:31.412.000 [mindspore/train/serialization.py:1956] For 'load_param_into_net', remove parameter prefix name: model., continue to load.\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "==================================================\n", + "Missing keys: []\n", + "Checkpoint loaded successfully!\n", + "==================================================\n", + "Total Parameters: 14214275\n", + "Test dataset size: 1\n", + ".ice_input_20170109.npy RMSE: 0.06927543270552072\n", + "Average RMSE: 0.06927543270552072\n", + "Maximum RMSE: 0.06927543270552072\n", + "Start evaluate!\n", + "ice_input_20170107.npy\n", + "========================================= break ==================================================\n", + "========================================= break ==================================================\n", + "ice_input_20170102.npy\n", + "========================================= break ==================================================\n", + "========================================= break ==================================================\n", + "ice_input_20170101.npy\n", + "========================================= break ==================================================\n", + "========================================= break ==================================================\n", + "ice_input_20170106.npy\n", + "========================================= break ==================================================\n", + "========================================= break ==================================================\n", + "ice_input_20170108.npy\n", + "========================================= break ==================================================\n", + "========================================= break ==================================================\n", + "ice_input_20170103.npy\n", + "========================================= break ==================================================\n", + "========================================= break ==================================================\n", + "ice_input_20170109.npy\n", + "========================================= break ==================================================\n", + "========================================= break ==================================================\n", + "ice_input_20170104.npy\n", + "========================================= break ==================================================\n", + "========================================= break ==================================================\n", + "ice_input_20170105.npy\n", + "========================================= break ==================================================\n", + "========================================= break ==================================================\n", + "ice_input_20170110.npy\n", + "========================================= break ==================================================\n", + "========================================= break ==================================================\n", + "ice_input_20170107.npy\n", + "========================================= break ==================================================\n", + "========================================= break ==================================================\n", + "ice_input_20170102.npy\n", + "========================================= 
break ==================================================\n", + "========================================= break ==================================================\n", + "ice_input_20170101.npy\n", + "========================================= break ==================================================\n", + "========================================= break ==================================================\n", + "ice_input_20170106.npy\n", + "========================================= break ==================================================\n", + "========================================= break ==================================================\n", + "ice_input_20170108.npy\n", + "========================================= break ==================================================\n", + "========================================= break ==================================================\n", + "ice_input_20170103.npy\n", + "========================================= break ==================================================\n", + "========================================= break ==================================================\n", + "ice_input_20170109.npy\n", + "========================================= break ==================================================\n", + "========================================= break ==================================================\n", + "ice_input_20170104.npy\n", + "========================================= break ==================================================\n", + "========================================= break ==================================================\n", + "ice_input_20170105.npy\n", + "========================================= break ==================================================\n", + "========================================= break ==================================================\n", + "ice_input_20170110.npy\n", + "========================================= break ==================================================\n", + "========================================= break ==================================================\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1212/1212 [08:57<00:00, 2.26it/s]\n", + "/tmp/ipykernel_1012809/934503037.py:629: RuntimeWarning: All-NaN slice encountered\n", + " disnrst = np.nanmin(dismin, axis=0)\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "247.6625131308501 0.0 21.980495081736816\n", + "detect_result_ice_input_20170108.npy dis width is: 1.3631642412525773\n", + "detect_result_ice_input_20170108.npy dis diff is: 0.1042636932498877\n", + "detect_result_ice_input_20170108.npy degree diff is: 4.717008703899891\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1372/1372 [10:26<00:00, 2.19it/s]\n", + "/tmp/ipykernel_1012809/934503037.py:629: RuntimeWarning: All-NaN slice encountered\n", + " disnrst = np.nanmin(dismin, axis=0)\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "289.4160607063347 0.0 27.346576883106486\n", + "detect_result_ice_input_20170104.npy dis width is: 0.9752119597121295\n", + "detect_result_ice_input_20170104.npy dis 
diff is: 0.10289706432384488\n", + "detect_result_ice_input_20170104.npy degree diff is: 6.996797847345313\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1381/1381 [08:49<00:00, 2.61it/s]\n", + "/tmp/ipykernel_1012809/934503037.py:629: RuntimeWarning: All-NaN slice encountered\n", + " disnrst = np.nanmin(dismin, axis=0)\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "311.4264251066365 0.0 23.031132891707415\n", + "detect_result_ice_input_20170103.npy dis width is: 1.10339485188626\n", + "detect_result_ice_input_20170103.npy dis diff is: 0.0864150463915106\n", + "detect_result_ice_input_20170103.npy degree diff is: 6.265153611048445\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1303/1303 [11:06<00:00, 1.96it/s]\n", + "/tmp/ipykernel_1012809/934503037.py:629: RuntimeWarning: All-NaN slice encountered\n", + " disnrst = np.nanmin(dismin, axis=0)\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "256.20677516485057 0.0 21.702993657408058\n", + "detect_result_ice_input_20170110.npy dis width is: 1.1094030108893287\n", + "detect_result_ice_input_20170110.npy dis diff is: 0.08734691066267006\n", + "detect_result_ice_input_20170110.npy degree diff is: 5.833586445174906\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1302/1302 [10:10<00:00, 2.13it/s]\n", + "/tmp/ipykernel_1012809/934503037.py:629: RuntimeWarning: All-NaN slice encountered\n", + " disnrst = np.nanmin(dismin, axis=0)\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "239.01588616995642 0.0 19.85132109877317\n", + "detect_result_ice_input_20170107.npy dis width is: 1.165828382806233\n", + "detect_result_ice_input_20170107.npy dis diff is: 0.08832406718394695\n", + "detect_result_ice_input_20170107.npy degree diff is: 6.283825075937445\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1301/1301 [11:51<00:00, 1.83it/s]\n", + "/tmp/ipykernel_1012809/934503037.py:629: RuntimeWarning: All-NaN slice encountered\n", + " disnrst = np.nanmin(dismin, axis=0)\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "215.84032367222633 0.0 21.73022214705445\n", + "detect_result_ice_input_20170106.npy dis width is: 0.9514659075975823\n", + "detect_result_ice_input_20170106.npy dis diff is: 0.14328743272156802\n", + "detect_result_ice_input_20170106.npy degree diff is: 7.559965475096256\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1291/1291 [08:31<00:00, 2.52it/s]\n", + 
"/tmp/ipykernel_1012809/934503037.py:629: RuntimeWarning: All-NaN slice encountered\n", + " disnrst = np.nanmin(dismin, axis=0)\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "346.3092233456352 0.0 29.625454195807148\n", + "detect_result_ice_input_20170109.npy dis width is: 0.9989966988881656\n", + "detect_result_ice_input_20170109.npy dis diff is: 0.10743155439241363\n", + "detect_result_ice_input_20170109.npy degree diff is: 6.83757206490284\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1309/1309 [08:40<00:00, 2.51it/s]\n", + "/tmp/ipykernel_1012809/934503037.py:629: RuntimeWarning: All-NaN slice encountered\n", + " disnrst = np.nanmin(dismin, axis=0)\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "941.0408776799075 0.0 24.629606081423983\n", + "detect_result_ice_input_20170101.npy dis width is: 0.9382144248848047\n", + "detect_result_ice_input_20170101.npy dis diff is: 0.09824832831725389\n", + "detect_result_ice_input_20170101.npy degree diff is: 5.501169243719244\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1328/1328 [08:24<00:00, 2.63it/s]\n", + "/tmp/ipykernel_1012809/934503037.py:629: RuntimeWarning: All-NaN slice encountered\n", + " disnrst = np.nanmin(dismin, axis=0)\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "335.33248763098686 0.0 33.43748652203445\n", + "detect_result_ice_input_20170102.npy dis width is: 1.3782899260505495\n", + "detect_result_ice_input_20170102.npy dis diff is: 0.09245287652421691\n", + "detect_result_ice_input_20170102.npy degree diff is: 6.778529156282445\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1359/1359 [12:32<00:00, 1.81it/s]\n", + "/tmp/ipykernel_1012809/934503037.py:629: RuntimeWarning: All-NaN slice encountered\n", + " disnrst = np.nanmin(dismin, axis=0)\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "485.3900802751177 0.0 20.469466739220614\n", + "detect_result_ice_input_20170105.npy dis width is: 2.167959851413929\n", + "detect_result_ice_input_20170105.npy dis diff is: 0.0741817919181337\n", + "detect_result_ice_input_20170105.npy degree diff is: 5.95075356676178\n", + "avg diff width: 1.215192925538156\n", + "avg diff dis: 0.09848487656854464\n", + "avg diff degree: 6.272436119016858\n", + "max diff width: 2.167959851413929\n", + "max diff dis: 0.14328743272156802\n", + "max diff degree: 7.559965475096256\n", + "acc for detect_result_ice_input_20170108.npy is: 0.98944425\n", + "acc for detect_result_ice_input_20170104.npy is: 0.98901625\n", + "acc for detect_result_ice_input_20170103.npy is: 0.98956825\n", + "acc for detect_result_ice_input_20170110.npy is: 0.98920175\n", + "acc for detect_result_ice_input_20170107.npy is: 0.989827\n", + "acc for detect_result_ice_input_20170106.npy is: 0.988562\n", + "acc for detect_result_ice_input_20170109.npy is: 0.989696\n", + "acc for 
detect_result_ice_input_20170101.npy is: 0.9889415\n", + "acc for detect_result_ice_input_20170102.npy is: 0.9892355\n", + "acc for detect_result_ice_input_20170105.npy is: 0.9891005\n", + "avg acc is: 0.9892593\n", + "Evaluation completed!\n" + ] + } + ], + "source": [ + "evaluator = Tester(config)\n", + "evaluator.evaluate()" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.11" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/MindEarth/applications/sea/LeadFormer/README.md b/MindEarth/applications/sea/LeadFormer/README.md new file mode 100644 index 0000000000000000000000000000000000000000..23a918c333e804a8aebcb4e493bdc29da260932a --- /dev/null +++ b/MindEarth/applications/sea/LeadFormer/README.md @@ -0,0 +1,87 @@ +# LeadFormer: High-Resolution Intelligent Forecasting of Arctic Sea Ice + +## Overview + +Leads are linear fracture zones formed in sea ice under the influence of waves, wind, and ocean currents. Their morphological + characteristics reflect the intensity of substance and energy exchange between the ocean and the atmosphere, influencing turbulent heat fluxes on the lead surface. + Therefore, accurately characterizing the morphology and spatial distribution of leads is crucial for studying Arctic sea ice changes and predicting navigational routes. + +The morphological features of leads include length, width, and orientation (tilt angle). + +- Lead width largely determines the intensity of heat and moisture exchange between the atmosphere and ocean +- Lead orientation reflects and influences sea ice dynamics +- Total lead length serves as an indicator for measuring scale variations, seasonal changes, and interannual variability of leads + +High-resolution sea ice lead forecasting models are key technological tools for addressing the rapid changes in Arctic sea ice under global warming. + To tackle the complexity of sea ice change mechanisms and the uncertainty in sea ice forecasting, ***LeadFormer*** leverages Arctic high-resolution + numerical model data and a Transformer-based artificial intelligence model. It achieves intelligent forecasting of Arctic leads, covering the pan-Arctic + region with a high-resolution ice condition forecasting system at 2 km resolution. + +The model framework is shown in the figure below: + +![LeadFormer](images/model.png) + +The model adopts an encoder-decoder framework: + +- **Encoding stage**: Compresses and deepens features through overlapping block embedding and a four-level Transformer block structure +- **Decoding stage**: Gradually reconstructs spatial dimensions via MLP (Multi-Layer Perceptron) and upsampling operations +- **Core innovation**: Fuses global modeling capability of Transformers with local perception characteristics of CNNs (Convolutional Neural Networks), + making it suitable for high-precision image processing tasks + +The dataset for this model is currently not open-source; only the code is open-sourced. + +## Quick Start + +Prepare your data, then modify the `data_path` in `./configs/2km_ice_config.yaml` (data not currently open source). 
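+
+For example, the `data` section of `./configs/2km_ice_config.yaml` (the full file is added later in this change) is where that path lives; point `data_path` at the directory containing the `ice_input`/`ice_label` folders expected by `src/data_loader.py`. The snippet below just quotes the shipped defaults; the inline comment is illustrative and not part of the config file:
+
+```yaml
+data:
+  data_path: "./test_dataset"   # replace with your local dataset directory
+  dataset: "2km_ice"
+```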
+ +### Running Method: Call the `main` script from the command line + +```python +python main.py --device_id 0 --device_target Ascend --cfg ./configs/diffusion_cfg.yaml --mode train +``` + +Where: + +- `--device_target` indicates the device type, default is Ascend. +- `--device_id` indicates the ID of the running device, default is 0. +- `--cfg` is the path to the configuration file, default is "./configs/2km_ice_config.yaml". +- `--mode` is the running mode, default is train. + +### Inference + +Set `model_checkpoint` in `./configs/2km_ice_config.yaml` to the path of the diffusion model checkpoint. + +```python +python main.py --device_id 0 --mode test +``` + +### Result Display + +#### Prediction Result Visualization + +The following figure shows the results obtained after training with 728 samples for 30 epochs and then performing inference. +In the figure, the black outlines represent the topography, and the colored bands represent the prediction results. + +![LeadFormer](images/result.jpg) + +### Performance + +| Parameter | NPU | +|:----------------------:|:--------------------------:| +| Hardware Version | Ascend, 64G | +| MindSpore Version | 2.5.0 | +| Dataset | Polar Region Images | +| Training Parameters | batch_size=1, steps_per_epoch=728, epochs=30 | +| Testing Parameters | batch_size=1, steps=44 | +| Optimizer | AdamW | +| Training Loss (RMSE) | 0.07727 | +| Lead Detection Prediction Accuracy (Acc) | 98.90112% | +| Lead Length Deviation | 0.09848% | +| Lead Angle Deviation | 6.27244° | +| Lead Width Deviation | 1.21519% | +| Training Resources | 1 Node 8 NPUs | + +## Contributors + +**gitee id**: funfunplus +**email**: funniless@163.com \ No newline at end of file diff --git a/MindEarth/applications/sea/LeadFormer/README_CN.md b/MindEarth/applications/sea/LeadFormer/README_CN.md new file mode 100644 index 0000000000000000000000000000000000000000..bc9362db237b05ade3f00f001db5134424cba69a --- /dev/null +++ b/MindEarth/applications/sea/LeadFormer/README_CN.md @@ -0,0 +1,68 @@ +# LeadFormer: 北极海冰高分辨率智能预报 + +## 概述 + +冰间水道是海冰在海浪、风力和洋流作用下形成的线状断裂带,其形态特征能够反映海洋与大气之间物质能量交换的强度,影响着水道表面的湍流热通量。因此,冰间水道的形态及空间分布的准确刻画对研究北极的海冰变化和预测航道通航具有重要意义。冰间水道的形态特征包括长度、宽度和倾角等。冰间水道宽度在一定程度上决定了大气和海洋水热交换的强度,水道倾角反应且影响海冰动力学特征,水道总长度可以作为衡量冰间水道尺度变异及季节和年际变化的指标。高分辨率海冰冰间水道预测模型是当前应对全球气候变暖背景下北极海冰快速变化的关键技术工具。针对海冰变化机理的复杂性和海冰预报的不确定性,***LeadFormer***以北极高分辨率数值模式数据和基于transformer的人工智能模型为支撑,实现北极冰间水道的智能预报,区域覆盖泛北极,分辨率达到2km的高分辨率冰情预报体系。 +模型框架图入下图所示 + +![LeadFormer](images/model.png) + +该模型采用编码器-解码器框架,编码阶段通过重叠块嵌入和四级Transformer块实现特征压缩与深化;解码阶段通过MLP和上采样操作逐步重建空间维度;核心创新在于融合Transformer的全局建模能力与CNN的局部感知特性,适用于高精度图像处理任务。 + +本模型数据集暂不开源,仅开源代码。 + +## 快速开始 + +准备数据,然后在`./configs/2km_ice_config.yaml`中修改`data_path`路径(暂不开源数据)。 + +### 运行方式: 在命令行调用`main`脚本 + +```python + +python main.py --device_id 0 --device_target Ascend --cfg ./configs/diffusion_cfg.yaml --mode train + +``` + +其中, --device_target 表示设备类型,默认Ascend。 --device_id 表示运行设备的编号,默认值0。 --cfg 配置文件路径,默认值"./configs/2km_ice_config.yaml"。 --mode 运行模式,默认值train + +### 推理 + +在`./configs/2km_ice_config.yaml`中设置`model_checkpoint`为diffusion模型ckpt地址。 + +```python + +python main.py --device_id 0 --mode test + +``` + +### 结果展示: + +#### 预测结果可视化 + +下图展示了使用728条样本训练30个epoch后进行推理绘制的结果。 +图中,黑色轮廓为地形,彩色条纹为预测结果。 + +![LeadFormer](images/result.jpg) + +### 性能 + +| Parameter | NPU | +|:----------------------:|:--------------------------:| +| 硬件版本 | Ascend, 64G | +| mindspore版本 | 2.5.0 | +| 数据集 | 极区图像 | +| 训练参数 | batch_size=1, steps_per_epoch=728, epochs=30 | +| 测试参数 | batch_size=1,steps=44 | +| 优化器 | AdamW | +| 训练损失(RMSE) | 
0.07727 | +| 冰间水道识别预报准确率(Acc) | 98.90112% | +| 冰间水道长度偏差 | 0.09848% | +| 冰间水道角度偏差 | 6.27244° | +| 冰间水道宽度偏差 | 1.21519% | +| 训练资源 | 1Node 8NPU | + +## 贡献者 + +gitee id: Zhou Chuansai, funfunplus + +email: chuansaizhou@163.com, funniless@163.com \ No newline at end of file diff --git a/MindEarth/applications/sea/LeadFormer/configs/2km_ice_config.yaml b/MindEarth/applications/sea/LeadFormer/configs/2km_ice_config.yaml new file mode 100644 index 0000000000000000000000000000000000000000..0d44b2f431536f5fbfdc2e2a8366d940e8faff8c --- /dev/null +++ b/MindEarth/applications/sea/LeadFormer/configs/2km_ice_config.yaml @@ -0,0 +1,46 @@ +model: + name: "ice_simple" + include_background: True + num_classes: 1 + num_channels: 1 + filter_weight: ["outc.weight", "outc.bias"] + in_channels: [64, 128, 320, 512] + +data: + data_path: "./test_dataset" + dataset: "2km_ice" + crop: None + train_augment: False + batch_size: 1 + cross_valid_ind: 1 + eval_resize: False + data_shape: [3360, 3072] + +optimizer: + lr: 0.0001 + weight_decay: 0.01 + loss_scale: 1024.0 + FixedLossScaleManager: 1024.0 + +summary: + output_path: "./train" + checkpoint_path: "./output/ckpt" + keep_checkpoint_max: 1 + show_eval: True + load_path: "./ckpt_0/best.ckpt" + +train: + device_target: "Ascend" + epochs: 300 + repeat: 1 + distribute_epochs: 300 + run_distribute: False + resume: False + resume_ckpt: "./" + transfer_training: False + enable_profiling: False + amp_level: "O3" + +test: + model_checkpoint: "./ckpt" + output_path: "./output/" diff --git a/MindEarth/applications/sea/LeadFormer/images/model.png b/MindEarth/applications/sea/LeadFormer/images/model.png new file mode 100644 index 0000000000000000000000000000000000000000..586b7051c5863d420e0e9d13b238d8d4dc9df1c0 Binary files /dev/null and b/MindEarth/applications/sea/LeadFormer/images/model.png differ diff --git a/MindEarth/applications/sea/LeadFormer/images/result.jpg b/MindEarth/applications/sea/LeadFormer/images/result.jpg new file mode 100644 index 0000000000000000000000000000000000000000..950fd91fde598115d4c4f3fad606238a9fda02c2 Binary files /dev/null and b/MindEarth/applications/sea/LeadFormer/images/result.jpg differ diff --git a/MindEarth/applications/sea/LeadFormer/main.py b/MindEarth/applications/sea/LeadFormer/main.py new file mode 100644 index 0000000000000000000000000000000000000000..e3370b5a0149d4c88c605dae6e3c29256d380088 --- /dev/null +++ b/MindEarth/applications/sea/LeadFormer/main.py @@ -0,0 +1,71 @@ +# Copyright 2023 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"main" +import argparse + +import mindspore as ms +from mindspore import set_seed, context +from mindearth.utils import load_yaml_config + +from src.solver import Trainer +from src.forecast import Tester + +set_seed(0) + + +def get_parser(): + parser = argparse.ArgumentParser() + parser.add_argument("--device_id", default=5, type=int) + parser.add_argument("--device_target", default="Ascend", type=str) + parser.add_argument("--cfg", default="./configs/2km_ice_config.yaml", type=str) + parser.add_argument("--mode", default="test") + params = parser.parse_args() + return params + + +def train(cfg): + """train""" + epoch_size = ( + cfg["train"].get("epochs", 300) + if not cfg["train"].get("run_distribute", False) + else cfg["train"].get("distribute_epochs", 300) + ) + trainer = Trainer(cfg, epochs=epoch_size) + trainer.train() + + +def test(cfg): + """test""" + evaluator = Tester(cfg) + evaluator.evaluate() + + +if __name__ == "__main__": + args = get_parser() + config = load_yaml_config(args.cfg) + print("config", config) + if args.mode == "test": + context.set_context(mode=ms.PYNATIVE_MODE) + elif args.mode == "train": + context.set_context(mode=ms.GRAPH_MODE) + else: + raise ValueError(f"Invalid mode: '{args.mode}'. Expected 'test' or 'train'.") + ms.set_device(device_target=args.device_target, device_id=args.device_id) + if args.mode == "train": + train(config) + elif args.mode == "test": + test(config) + else: + raise ValueError(f"Invalid mode: '{args.mode}'. Expected 'test' or 'train'.") diff --git a/MindEarth/applications/sea/LeadFormer/scripts/run_distribute_train.sh b/MindEarth/applications/sea/LeadFormer/scripts/run_distribute_train.sh new file mode 100644 index 0000000000000000000000000000000000000000..3d96d8de47b97d7b631367ef6f2fe9e143b08c35 --- /dev/null +++ b/MindEarth/applications/sea/LeadFormer/scripts/run_distribute_train.sh @@ -0,0 +1,65 @@ +#!/bin/bash +# Copyright 2020-2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================ + +get_real_path() { + if [ "${1:0:1}" == "/" ]; then + echo "$1" + else + echo "$(realpath -m $PWD/$1)" + fi +} + +if [ $# != 3 ] +then + echo "==============================================================================================================" + echo "Usage: bash scripts/run_distribute_train.sh [RANK_TABLE_FILE] [DATASET] [CONFIG_PATH]" + echo "Please run the script as: " + echo "bash scripts/run_distribute_train.sh [RANK_TABLE_FILE] [DATASET] [CONFIG_PATH]" + echo "for example: bash run_distribute_train.sh /absolute/path/to/RANK_TABLE_FILE /absolute/path/to/data /absolute/path/to/config" + echo "==============================================================================================================" + exit 1 +fi +PROJECT_DIR=$(cd "$(dirname "$0")" || exit; pwd) +export HCCL_CONNECT_TIMEOUT=600 +export RANK_SIZE=8 +DATASET=$(get_real_path $2) +CONFIG_PATH=$(get_real_path $3) +RANK_TABLE=$(get_real_path $1) +export RANK_TABLE_FILE=$RANK_TABLE + +ulimit -u unlimited + +for((i=0;i env.log + + python ${PROJECT_DIR}/../train.py \ + --run_distribute=True \ + --data_path=$DATASET \ + --config_path=$CONFIG_PATH \ + --output_path './output' > log.txt 2>&1 & + + cd ../ +done diff --git a/MindEarth/applications/sea/LeadFormer/src/__init__.py b/MindEarth/applications/sea/LeadFormer/src/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/MindEarth/applications/sea/LeadFormer/src/backbone.py b/MindEarth/applications/sea/LeadFormer/src/backbone.py new file mode 100644 index 0000000000000000000000000000000000000000..7db2b578a24eecdcb413f718d363ba45db3630e7 --- /dev/null +++ b/MindEarth/applications/sea/LeadFormer/src/backbone.py @@ -0,0 +1,509 @@ +# right 2020-2022 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================ +"""model""" +import math + +import mindspore.nn as nn +import mindspore.ops as ops +import mindspore.numpy as msnp +from mindspore.ops import linspace +from mindspore.common.initializer import initializer, TruncatedNormal, Normal + + +def init_weights(m): + """init_weights""" + if isinstance(m, nn.Dense): + m.weight.set_data( + initializer(TruncatedNormal(0.02), m.weight.shape, m.weight.dtype) + ) + if m.bias is not None: + m.bias.set_data(initializer("zeros", m.bias.shape, m.bias.dtype)) + elif isinstance(m, nn.LayerNorm): + m.gamma.set_data(initializer("ones", m.gamma.shape, m.gamma.dtype)) + m.beta.set_data(initializer("zeros", m.beta.shape, m.beta.dtype)) + elif isinstance(m, nn.Conv2d): + fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels + fan_out //= m.group + m.weight.set_data( + initializer(Normal(math.sqrt(2.0 / fan_out)), m.weight.shape, m.weight.dtype) + ) + if m.bias is not None: + m.bias.set_data(initializer("zeros", m.bias.shape, m.bias.dtype)) + + +class OverlapPatchEmbed(nn.Cell): + """ + Implements overlapping patch embedding with convolutional projection. + + This module splits input images into overlapping patches using a convolutional + projection layer, then applies layer normalization. Designed for vision transformers + with overlapping patch strategies. + + Args: + patch_size (int): Size of the sliding window (kernel size). Default: 7 + stride (int): Stride for convolution operation. Controls patch overlap. Default: 4 + in_chans (int): Number of input channels. Default: 3 (RGB) + embed_dim (int): Dimension of embedding output. Default: 768 + + Shape: + Input: (B, C, H, W) where B=batch_size, C=in_chans + Output: (B, num_patches, embed_dim) + """ + def __init__(self, patch_size=7, stride=4, in_chans=3, embed_dim=768): + super().__init__() + patch_size = (patch_size, patch_size) + self.proj = nn.Conv2d( + in_chans, + embed_dim, + kernel_size=patch_size, + stride=stride, + padding=( + patch_size[0] // 2, + patch_size[1] // 2, + patch_size[0] // 2, + patch_size[1] // 2, + ), + pad_mode="pad", + has_bias=True, + bias_init=None, + ) + self.norm = nn.LayerNorm((embed_dim,)) + self.apply(init_weights) + + def construct(self, x): + x = self.proj(x) + b, c, h, w = x.shape + + x = x.reshape(b, c, h * w) + x = ops.swapaxes(x, 1, 2) + x = self.norm(x) + + return x + + +class Attention(nn.Cell): + """ + Implements multi-head self-attention with spatial reduction. + + This attention module optionally incorporates spatial reduction (SR) to reduce + computational complexity for high-resolution inputs. When sr_ratio > 1, it applies + convolution to reduce spatial dimensions before key/value computation. + + Args: + dim (int): Input feature dimension + num_heads (int): Number of attention heads. Default: 8 + qkv_bias (bool): Enable bias for query/key/value projections. Default: False + qk_scale (float): Override default scale factor (1/sqrt(d_k)). Default: None + attn_drop (float): Attention dropout probability. Default: 0.0 + proj_drop (float): Output projection dropout probability. Default: 0.0 + sr_ratio (int): Spatial reduction ratio. 
Default: 1 (no reduction) + + Shape: + Input: (B, N, C) where: + B = batch size + N = sequence length (H * W) + C = feature dimension (dim) + Output: (B, N, C) (same shape as input) + """ + def __init__( + self, + dim, + num_heads=8, + qkv_bias=False, + qk_scale=None, + attn_drop=0.0, + proj_drop=0.0, + sr_ratio=1, + ): + super().__init__() + assert ( + dim % num_heads == 0 + ), f"dim {dim} should be divided by num_heads {num_heads}." + self.dim = dim + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = qk_scale or head_dim**-0.5 + self.q = nn.Dense(dim, dim, has_bias=qkv_bias) + self.sr_ratio = sr_ratio + if sr_ratio > 1: + self.sr = nn.Conv2d( + dim, dim, kernel_size=sr_ratio, has_bias=True, stride=sr_ratio + ) + self.norm = nn.LayerNorm(normalized_shape=(dim,)) + self.kv = nn.Dense(dim, dim * 2, has_bias=qkv_bias) + self.attn_drop = nn.Dropout(p=attn_drop) + self.proj = nn.Dense(dim, dim) + self.proj_drop = nn.Dropout(p=proj_drop) + self.apply(init_weights) + + def construct(self, x, h, w): + """construct""" + b, n, c = x.shape + q = ( + self.q(x) + .reshape(b, n, self.num_heads, c // self.num_heads) + .permute(0, 2, 1, 3) + ) + + if self.sr_ratio > 1: + x_ = ops.permute(x, (0, 2, 1)) + x_ = x_.reshape(b, c, h, w) + x_ = self.sr(x_).reshape(b, c, -1).permute(0, 2, 1) + x_ = self.norm(x_) + kv = ( + self.kv(x_) + .reshape(b, -1, 2, self.num_heads, c // self.num_heads) + .permute(2, 0, 3, 1, 4) + ) + else: + kv = ( + self.kv(x) + .reshape(b, -1, 2, self.num_heads, c // self.num_heads) + .permute(2, 0, 3, 1, 4) + ) + k, v = kv[0], kv[1] + + attn = (q @ k.swapaxes(-2, -1)) * self.scale + attn = nn.Softmax(axis=-1)(attn) + attn = self.attn_drop(attn) + + x = (attn @ v).swapaxes(1, 2).reshape(b, n, c) + x = self.proj(x) + x = self.proj_drop(x) + + return x + + +def drop_path( + x, drop_prob: float = 0.0, training: bool = False, scale_by_keep: bool = True +): + """ + Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). + This is the same as the DropConnect impl I created for EfficientNet, etc networks, however, + the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper... + See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for + changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use + 'survival rate' as the argument. + """ + if drop_prob == 0.0 or not training: + return x + keep_prob = 1 - drop_prob + shape = (x.shape[0],) + (1,) * (x.ndim - 1) + if keep_prob > 0.0 and scale_by_keep: + mask = keep_prob + ops.rand(shape, dtype=x.dtype) + mask = msnp.floor(mask) + x = ops.div(x, keep_prob) * mask + return x + + +class DropPath(nn.Cell): + """ + Actual implementation of stochastic depth (drop path) regularization. + + Args: + x (Tensor): Input tensor + drop_prob (float): Probability of dropping a sample + training (bool): Whether in training mode + scale_by_keep (bool): Whether to scale outputs by keep probability + + Returns: + Tensor: Output tensor after applying drop path + """ + def __init__(self, drop_prob=None, scale_by_keep=True): + super().__init__() + self.drop_prob = drop_prob + self.scale_by_keep = scale_by_keep + + def construct(self, x): + return drop_path(x, self.drop_prob, self.training, self.scale_by_keep) + + +class DWConv(nn.Cell): + """ + Depthwise Convolution Layer + + This module implements a depthwise separable convolution operation. 
+ Depthwise convolution applies a single filter per input channel, making it + computationally efficient while still capturing spatial features. + + Args: + dim (int): Number of input/output channels (default: 768) + """ + def __init__(self, dim=768): + super().__init__() + self.dwconv = nn.Conv2d( + dim, + dim, + kernel_size=3, + stride=1, + pad_mode="pad", + padding=1, + has_bias=True, + group=dim, + ) + + def construct(self, x, h, w): + """construct""" + b, _, c = x.shape + x = x.swapaxes(1, 2).view(b, c, h, w) + x = self.dwconv(x) + x = x.flatten(start_dim=2).swapaxes(1, 2) + + return x + + +class Mlp(nn.Cell): + """ + Multi-Layer Perceptron with Depthwise Convolution + + This module implements a multi-layer perceptron with an intermediate depthwise convolution layer. + It consists of two fully connected layers with an activation function and dropout in between. + A depthwise convolution is applied after the first fully connected layer to incorporate spatial information. + + Args: + in_features (int): Number of input features + hidden_features (int): Number of hidden features (default: None, same as in_features) + out_features (int): Number of output features (default: None, same as in_features) + act_layer (nn.Cell): Activation function (default: nn.GELU) + drop (float): Dropout probability (default: 0.0) + """ + def __init__( + self, + in_features, + hidden_features=None, + out_features=None, + act_layer=nn.GELU, + drop=0.0, + ): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + + self.fc1 = nn.Dense(in_features, hidden_features) + self.dwconv = DWConv(hidden_features) + self.act = act_layer() + + self.fc2 = nn.Dense(hidden_features, out_features) + + self.drop = nn.Dropout(p=drop) + self.apply(init_weights) + + def construct(self, x, h, w): + """construct""" + x = self.fc1(x) + x = self.dwconv(x, h, w) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + + +class Block(nn.Cell): + """ + Generic Transformer Block for all stages of the network architecture. + + Args: + dim (int): Dimension of input features + num_heads (int): Number of attention heads + spatial_dims (tuple): Spatial dimensions (H, W) for current stage + mlp_ratio (float): Ratio for hidden dimension in MLP (default: 4.0) + qkv_bias (bool): Whether to include bias in QKV projection (default: False) + qk_scale (float): Scaling factor for QK dot product (default: None) + drop (float): Dropout rate for projection layers (default: 0.0) + attn_drop (float): Dropout rate for attention weights (default: 0.0) + drop_path (float): Drop path probability for this block (default: 0.0) + act_layer (nn.Cell): Activation function (default: nn.GELU) + sr_ratio (int): Spatial reduction ratio for attention (default: 1) + """ + def __init__( + self, + dim, + num_heads, + spatial_dims, + qkv_bias=False, + qk_scale=None, + drop=0.0, + attn_drop=0.0, + block_drop_path=0.0, + act_layer=nn.GELU, + sr_ratio=1, + mlp_ratio=4. 
+ ): + super().__init__() + self.norm1 = nn.LayerNorm((dim,)) + self.attn = Attention( + dim, + num_heads=num_heads, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + attn_drop=attn_drop, + proj_drop=drop, + sr_ratio=sr_ratio, + ) + self.norm2 = nn.LayerNorm((dim,)) + self.mlp = Mlp( + in_features=dim, + hidden_features=int(dim * mlp_ratio), + act_layer=act_layer, + drop=drop, + ) + self.drop_path = ( + DropPath(drop_prob=block_drop_path) if block_drop_path > 0.0 else nn.Identity() + ) + self.spatial_dims = spatial_dims + self.apply(init_weights) + + def construct(self, x): + h, w = self.spatial_dims + x = self.norm1(x) + x = x + self.attn(x, h, w) + x = x + self.mlp(self.norm2(x), h, w) + return x + + +class MixVisionTransformer(nn.Cell): + """Hierarchical Vision Transformer backbone with multi-scale feature extraction. + + Implements a multi-stage transformer architecture with spatial reduction for + efficient high-resolution processing. Produces feature maps at multiple scales + suitable for dense prediction tasks like segmentation. + + Args: + in_chans (int): Input channels. Default: 6 (RGB + auxiliary) + num_classes (int): Output classes for classification head. Default: 1000 + embed_dims (list): Feature dimensions for each stage. Default: [64, 128, 320, 512] + num_heads (list): Attention heads per stage. Default: [1, 2, 4, 8] + mlp_ratios (list): MLP expansion ratios. Default: [4, 4, 4, 4] + qkv_bias (bool): Enable bias in QKV projections. Default: False + qk_scale (float): Custom QK scaling factor. Default: None + drop_rate (float): General dropout rate. Default: 0.0 + attn_drop_rate (float): Attention dropout rate. Default: 0.0 + drop_path_rate (float): Stochastic depth rate. Default: 0.0 + depths (list): Number of transformer blocks per stage. Default: [3, 4, 6, 3] + sr_ratios (list): Spatial reduction ratios per stage. 
Default: [8, 4, 2, 1] + + Input Shape: + (B, in_chans, H, W) # Typically 512x512 for segmentation tasks + + Output Shape: + List of 4 feature maps at different scales: + [ + (B, embed_dims[0], H/4, W/4), # Stage1: 128x128 + (B, embed_dims[1], H/8, W/8), # Stage2: 64x64 + (B, embed_dims[2], H/16, W/16), # Stage3: 32x32 + (B, embed_dims[3], H/32, W/32) # Stage4: 16x16 + ] + """ + def __init__( + self, + in_chans=6, + num_classes=1000, + embed_dims=None, + num_heads=None, + mlp_ratios=None, + qkv_bias=False, + qk_scale=None, + drop_rate=0.0, + attn_drop_rate=0.0, + drop_path_rate=0.0, + norm_layer=nn.LayerNorm, + depths=None, + sr_ratios=None, + high=None, + weight=None + ): + super().__init__() + self.embed_dims = embed_dims if embed_dims is not None else [64, 128, 320, 512] + self.num_heads = num_heads if num_heads is not None else [1, 2, 4, 8] + self.mlp_ratios = mlp_ratios if mlp_ratios is not None else [4, 4, 4, 4] + self.sr_ratios = sr_ratios if sr_ratios is not None else [8, 4, 2, 1] + self.num_classes = num_classes + self.depths = depths if depths is not None else [3, 4, 6, 3] + self.h = high if high is not None else [500, 250, 125, 63] + self.w = weight if weight is not None else [500, 250, 125, 63] + dpr = [x.item() for x in linspace(0, drop_path_rate, sum(self.depths))] + self.patch_embed1 = OverlapPatchEmbed( + patch_size=7, stride=4, in_chans=in_chans, embed_dim=self.embed_dims[0] + ) + current_depth = 0 + current_channels = in_chans + + for i in range(4): + patch_embed, blocks, norm = self.create_stage(i, current_channels, current_depth, + qkv_bias, qk_scale, drop_rate, + attn_drop_rate, dpr, norm_layer) + setattr(self, f'patch_embed{i+1}', patch_embed) + setattr(self, f'block{i+1}', blocks) + setattr(self, f'norm{i+1}', norm) + current_channels = self.embed_dims[i] + current_depth += self.depths[i] + self.apply(init_weights) + + def create_stage(self, index, in_chans, cur_depth, + qkv_bias, qk_scale, drop_rate, attn_drop_rate, dpr, norm_layer): + """create stage""" + patch_sizes = [7, 3, 3, 3] + strides = [4, 2, 2, 2] + patch_embed = OverlapPatchEmbed( + patch_size=patch_sizes[index], + stride=strides[index], + in_chans=in_chans, + embed_dim=self.embed_dims[index] + ) + blocks = nn.CellList([ + Block( + dim=self.embed_dims[index], + spatial_dims=(self.h[index], self.w[index]), + num_heads=self.num_heads[index], + mlp_ratio=self.mlp_ratios[index], + qkv_bias=qkv_bias, + qk_scale=qk_scale, + drop=drop_rate, + attn_drop=attn_drop_rate, + block_drop_path=dpr[cur_depth + i], + sr_ratio=self.sr_ratios[index], + ) + for i in range(self.depths[index]) + ]) + norm = norm_layer((self.embed_dims[index],)) + + return patch_embed, blocks, norm + + def _forward_stage(self, x, patch_embed, blocks, norm, h, w, b): + x = patch_embed(x) + for blk in blocks: + x = blk(x) + x = norm(x) + x = x.reshape(b, h, w, -1).permute(0, 3, 1, 2) + return x + + def construct(self, x): + """construct""" + b = x.shape[0] + outs = [] + + stages = [ + (self.patch_embed1, self.block1, self.norm1, self.h[0], self.w[0]), + (self.patch_embed2, self.block2, self.norm2, self.h[1], self.w[1]), + (self.patch_embed3, self.block3, self.norm3, self.h[2], self.w[2]), + (self.patch_embed4, self.block4, self.norm4, self.h[3], self.w[3]), + ] + + for patch_embed, blocks, norm, h, w in stages: + x = self._forward_stage(x, patch_embed, blocks, norm, h, w, b) + outs.append(x) + + return outs diff --git a/MindEarth/applications/sea/LeadFormer/src/data_loader.py b/MindEarth/applications/sea/LeadFormer/src/data_loader.py new file mode 
100644 index 0000000000000000000000000000000000000000..433be9f99e5a6a000d46ec4e8b65443702c75b00 --- /dev/null +++ b/MindEarth/applications/sea/LeadFormer/src/data_loader.py @@ -0,0 +1,191 @@ +# right 2020-2022 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""data_loader""" +import os +from collections import deque + +import numpy as np +from PIL import Image, ImageSequence + +import mindspore.dataset as ds + + +def _load_multipage_tiff(path): + """Load tiff images containing many images in the channel dimension""" + return np.array([np.array(p) for p in ImageSequence.Iterator(Image.open(path))]) + + +def _get_val_train_indices(length, fold, ratio=0.8): + """get_val_train_indices""" + assert 0 < ratio <= 1, "Train/total data ratio must be in range (0.0, 1.0]" + np.random.seed(0) + indices = np.arange(0, length, 1, dtype=np.int32) + np.random.shuffle(indices) + + if fold is not None: + indices = deque(indices) + indices.rotate(fold * round((1.0 - ratio) * length)) + indices = np.array(indices) + train_indices = indices[: round(ratio * len(indices))] + val_indices = indices[round(ratio * len(indices)) :] + else: + train_indices = indices + val_indices = [] + return train_indices, val_indices + + +def data_post_process(img, mask): + """data_post_process""" + img = np.expand_dims(img, axis=0) + mask = (mask > 0.5).astype(np.int32) + mask = (np.arange(mask.max() + 1) == mask[..., None]).astype(int) + mask = mask.transpose(2, 0, 1).astype(np.float32) + + return img, mask + + +def train_data_augmentation(img, mask, size=572): + """train_data_augmentation""" + h_flip = np.random.random() + if h_flip > 0.5: + img = np.flipud(img) + mask = np.flipud(mask) + v_flip = np.random.random() + if v_flip > 0.5: + img = np.fliplr(img) + mask = np.fliplr(mask) + + left = int(np.random.uniform() * 0.3 * size) + right = int((1 - np.random.uniform() * 0.3) * size) + top = int(np.random.uniform() * 0.3 * size) + bottom = int((1 - np.random.uniform() * 0.3) * size) + + img = img[top:bottom, left:right] + mask = mask[top:bottom, left:right] + + brightness = np.random.uniform(-0.2, 0.2) + img = np.float32(img + brightness * np.ones(img.shape)) + img = np.clip(img, -1.0, 1.0) + + return img, mask + + +class MultiClassDataset: + """ + Read image and mask from original images, and split all data into train_dataset and val_dataset by `split`. + Get image path and mask path from a tree of directories, + images within one folder is an image, the image file named `"image.png"`, the mask file named `"mask.png"`. 
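+    In this application the samples are `.npy` arrays rather than PNG images: inputs are
+    loaded from the `ice_input` folder and labels from the `ice_label` folder (see
+    `_read_img_mask` below).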
+ """ + def __init__(self, data_dir, repeat, is_train=False, split=0.8, shuffle=False): + self.data_dir = f"{data_dir.rstrip('/')}/ice_input" + self.label_dir = f"{data_dir.rstrip('/')}/ice_label" + self.is_train = is_train + self.split = split != 1.0 + if self.split: + self.img_ids = os.listdir(self.data_dir) + self.label_ids = os.listdir(self.label_dir) + self.train_ids = self.img_ids[: int(len(self.img_ids) * split)] * repeat + self.train_label_ids = ( + self.label_ids[: int(len(self.img_ids) * split)] * repeat + ) + self.val_ids = self.img_ids[int(len(self.img_ids) * split) :] + self.val_label_ids = self.label_ids[int(len(self.img_ids) * split) :] + else: + self.train_ids = sorted( + next(os.walk(os.path.join(self.data_dir, "train")))[1] + ) + self.val_ids = sorted(next(os.walk(os.path.join(self.data_dir, "val")))[1]) + if shuffle: + np.random.shuffle(self.train_ids) + + def _read_img_mask(self, img_id, label_id): + """read_img_mask""" + if self.split: + path = os.path.join(self.data_dir, img_id) + label_path = os.path.join(self.label_dir, label_id) + elif self.is_train: + path = os.path.join(self.data_dir, "train", img_id) + else: + path = os.path.join(self.data_dir, "val", img_id) + img = np.load(path, allow_pickle=True) + mask = np.load(label_path, allow_pickle=True) + img = img[:, :, 1:7] + return img, mask + + def __getitem__(self, index): + if self.is_train: + return self._read_img_mask( + self.train_ids[index], self.train_label_ids[index] + ) + return self._read_img_mask(self.val_ids[index], self.val_label_ids[index]) + + @property + def column_names(self): + column_names = ["image", "mask"] + return column_names + + def __len__(self): + if self.is_train: + return len(self.train_ids) + return len(self.val_ids) + + +def preprocess_img_mask(img, mask): + """ + Preprocess for multi-class dataset. + Random crop and flip images and masks when augment is True. + """ + img = img.astype(np.float32) + img = img.transpose(2, 0, 1) + mask = mask.transpose(2, 0, 1).astype(np.float32) + return img, mask + + +def create_multi_class_dataset( + data_dir, + repeat, + batch_size, + is_train=False, + split=0.8, + rank=0, + group_size=1, + shuffle=True, + num_parallel_workers=32 +): + """ + Get generator dataset for multi-class dataset. + """ + ds.config.set_enable_shared_mem(True) + mc_dataset = MultiClassDataset(data_dir, repeat, is_train, split, shuffle) + dataset = ds.GeneratorDataset( + mc_dataset, + mc_dataset.column_names, + shuffle=True, + num_shards=group_size, + shard_id=rank, + num_parallel_workers=num_parallel_workers, + python_multiprocessing=is_train, + ) + compose_map_func = preprocess_img_mask + dataset = dataset.map( + operations=compose_map_func, + input_columns=mc_dataset.column_names, + output_columns=mc_dataset.column_names, + num_parallel_workers=num_parallel_workers, + ) + dataset = dataset.batch( + batch_size, drop_remainder=is_train, num_parallel_workers=num_parallel_workers + ) + return dataset diff --git a/MindEarth/applications/sea/LeadFormer/src/forecast.py b/MindEarth/applications/sea/LeadFormer/src/forecast.py new file mode 100644 index 0000000000000000000000000000000000000000..ecce78409ee3f90002cbc0440c371f7c63069d57 --- /dev/null +++ b/MindEarth/applications/sea/LeadFormer/src/forecast.py @@ -0,0 +1,911 @@ +# right 2020-2022 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""inference""" +import os +from math import radians, sin, cos, degrees, atan2 + +import cv2 +import haversine.haversine +import numpy as np +import matplotlib.cm as cm +import matplotlib.pyplot as plt +from mpl_toolkits.basemap import Basemap +from tqdm import tqdm +from skimage import morphology + +import mindspore as ms +from mindspore import Tensor +from mindspore.train.serialization import load_checkpoint, load_param_into_net + +from .segformer import SegFormer +from .utils import make_grid, warp, readbin + + +class Tester: + """ + Sea Ice Segmentation Model Evaluator + + This class is used to evaluate the performance of sea ice segmentation models, including: + - Loading models and checkpoints + - Processing input data and making predictions + - Calculating various evaluation metrics + - Visualizing prediction results + - Assessing the detection performance of Linear Kinematic Features (LKF) + + Attributes: + config (dict): Configuration dictionary containing parameters such as model paths and data paths + model (SegFormer): Instance of the segmentation model + rmse_all (float): Cumulative Root Mean Square Error + max_rmse (float): Maximum Root Mean Square Error + output_dir_pred (str): Path for saving prediction results + output_dir_mask (str): Path for saving label results + """ + def __init__(self, config): + """ + Initialize the IceSegmentationEvaluator class. 
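# A minimal configuration sketch for this evaluator, using the nested keys the
# code below actually reads (config["model"], config["data"], config["test"]);
# every path and the data_shape value are hypothetical placeholders.
config = {
    "model": {"in_channels": [64, 128, 320, 512]},
    "data": {
        "data_path": "./dataset",     # must contain ice_input/, ice_label/, LONC.bin, LATC.bin, hFacC.data
        "data_shape": [4320, 2160],   # hypothetical full-grid shape handed to readbin
    },
    "test": {
        "model_checkpoint": "./ckpt/leadformer.ckpt",  # hypothetical checkpoint path
        "output_path": "./results",
    },
}
tester = Tester(config)
tester.evaluate()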
+ + Args: + config (dict): Configuration dictionary containing: + - model_checkpoint: Path to the model checkpoint file + - test_input_root: Path to test input data + - test_label_root: Path to test label data + - output_path: Path to save output + """ + self.config = config + self.in_channels = self.config["model"].get("in_channels") + # Initialize model + self.model = SegFormer( + in_channels=self.in_channels, + num_classes=1, + embedding_dim=256, + ) + self._load_checkpoint() + self.rmse_all = 0 + self.max_rmse = 0 + + def _load_checkpoint(self): + """Load model checkpoint.""" + param_dict = load_checkpoint(self.config["test"].get("model_checkpoint")) + missing_keys, unexpected_keys = load_param_into_net(self.model, param_dict) + print("=" * 50) + print("Missing keys:", missing_keys) + print("Unexpected keys:", unexpected_keys) + print("Checkpoint loaded successfully!") + print("=" * 50) + # Print total parameters + total_params = sum(param.size for param in self.model.get_parameters()) + print(f"Total Parameters: {total_params}") + + def _process_input(self, inp_path, label_path): + """Process input and label data.""" + # Load data + inp = np.load(inp_path, allow_pickle=True) + labels = np.load(label_path, allow_pickle=True) + # Transpose and reshape input + inp = inp.transpose(2, 0, 1) + ori_inp = inp[1, :, :] + inp = np.expand_dims(inp[1:7, :, :], axis=0) + inp = Tensor(inp, ms.float32) + # Process labels + labels = labels.transpose(2, 0, 1) + labels = np.expand_dims(labels, axis=0) + + return inp, labels, ori_inp + + def _make_prediction(self, inp, labels): + """Make prediction using the model.""" + intensity, motion = self.model(inp) + batch, _, height, width = inp.shape + # Reshape motion and intensity + motion_ = motion.reshape(batch, 1, 2, height, width) + intensity_ = intensity.reshape(batch, 1, 1, height, width) + # Process last frames + last_frames = Tensor(labels[:, 0, :, :], ms.float32).unsqueeze(dim=0) + # Create grid for warping + sample_tensor = np.zeros((1, 1, 2000, 2000)).astype(np.float32) + grid = Tensor(make_grid(sample_tensor), ms.float32) + my_grid = grid.tile((batch, 1, 1, 1)) + # Warp and predict + last_frames = warp( + last_frames, motion_[:, 0], my_grid, mode="nearest", padding_mode="border" + ) + tmp = last_frames + last_frames = last_frames + intensity_[:, 0] + pred = last_frames + + return pred, tmp + + def _calculate_metrics(self, pred, labels, ori_inp): + """Calculate evaluation metrics.""" + minus = pred[0, 0, :, :].asnumpy().T - labels[0, 1, :, :].T + minus_consist = np.squeeze(pred[0, 0, :, :].asnumpy()).T - np.squeeze(ori_inp).T + minus_label = np.squeeze(ori_inp).T - np.squeeze(labels[0, 1, :, :]).T + rmse = np.sqrt(np.mean(np.square(minus))) + return rmse, minus, minus_consist, minus_label + + def _save_results(self, pred_img, label_img, file_name): + """Save prediction and label results.""" + self.output_dir_pred = os.path.join(self.config["test"].get("output_path"), "pred") + os.makedirs(self.output_dir_pred, exist_ok=True) + self.output_dir_mask = os.path.join(self.config["test"].get("output_path"), "mask") + os.makedirs(self.output_dir_mask, exist_ok=True) + np.save(os.path.join(self.output_dir_pred, file_name), pred_img) + np.save(os.path.join(self.output_dir_mask, file_name), label_img) + + def _visualize_results( + self, ori_inp, pred, labels, minus, minus_consist, file_name + ): + """Visualize and save results.""" + plt.figure(figsize=(15, 10)) + + # Input + plt.subplot(231) + plt.pcolormesh(ori_inp.T, cmap=cm.gist_ncar_r, vmax=5, vmin=0) + 
plt.colorbar() + plt.xticks([]) + plt.yticks([]) + plt.title("input") + + # Prediction + plt.subplot(232) + plt.pcolormesh( + pred[0, 0, :, :].asnumpy().T, cmap=cm.gist_ncar_r, vmax=5, vmin=0 + ) + plt.colorbar() + plt.xticks([]) + plt.yticks([]) + plt.title("pred") + + # Label + plt.subplot(233) + plt.pcolormesh(labels[0, 1, :, :].T, cmap=cm.gist_ncar_r, vmax=5, vmin=0) + plt.colorbar() + plt.xticks([]) + plt.yticks([]) + plt.title("label") + + # Prediction - Label + plt.subplot(234) + plt.pcolormesh(minus, cmap=cm.RdBu_r, vmax=0.2, vmin=-0.2) + plt.colorbar() + plt.xticks([]) + plt.yticks([]) + plt.title("pred - label") + + # Prediction - Input + plt.subplot(235) + plt.pcolormesh(minus_consist, cmap=cm.RdBu_r, vmax=0.05, vmin=-0.05) + plt.colorbar() + plt.xticks([]) + plt.yticks([]) + plt.title("pred - input") + + # Label - Input + plt.subplot(236) + plt.pcolormesh( + labels[0, 1, :, :].T - ori_inp.T, cmap=cm.RdBu_r, vmax=0.2, vmin=-0.2 + ) + plt.colorbar() + plt.xticks([]) + plt.yticks([]) + plt.title("label - input") + + plt.tight_layout() + output_dir = os.path.join(self.config["test"].get("output_path"), "pig") + os.makedirs(output_dir, exist_ok=True) + plt.savefig(os.path.join(output_dir, file_name[:-4] + ".png"), dpi=600) + plt.close() + + def evaluate(self): + """Evaluate the model on test dataset.""" + output_dir_input = os.path.join(self.config["data"].get("data_path"), "ice_input") + os.makedirs(output_dir_input, exist_ok=True) + output_dir_label = os.path.join(self.config["data"].get("data_path"), "ice_label") + os.makedirs(output_dir_label, exist_ok=True) + test_files = os.listdir(output_dir_input) + print("Test dataset size:", len(test_files)) + + for file in test_files: + # Prepare paths + inp_path = os.path.join(output_dir_input, file) + label_path = os.path.join( + output_dir_label, file.split(".")[0] + ".label.npy" + ) + # Process input + inp, labels, ori_inp = self._process_input(inp_path, label_path) + # Make prediction + pred, _ = self._make_prediction(inp, labels) + # Calculate metrics + rmse, minus, minus_consist, _ = self._calculate_metrics( + pred, labels, ori_inp + ) + # Update metrics + self.rmse_all += rmse + if rmse > self.max_rmse: + self.max_rmse = rmse + print(file, "RMSE:", rmse) + # Save results + pred_img = pred[0, 0, :, :].asnumpy() + label_img = labels[0, 1, :, :] + self._save_results(pred_img, label_img, file) + # Visualize results + self._visualize_results(ori_inp, pred, labels, minus, minus_consist, file) + + # Print final metrics + avg_rmse = self.rmse_all / len(test_files) + print("Average RMSE:", avg_rmse) + print("Maximum RMSE:", self.max_rmse) + print("Start evaluate!") + self.evaluate_width_mask() + self.evaluate_width_pred() + self.evaluate_lead_all() + self.evaluate_acc() + print("Evaluation completed!") + + def to_lonlat_bin(self, lon, lat, x, y): + lat = np.asarray(lat) + lon = np.asarray(lon) + return lon[x][y], lat[x][y] + + def calconnectivity(self, target, x, y): + """calconnectivity""" + connectivity = 0 + is_special = 0 + is_endpoint = 0 + l_neighbor = [] + p_neighbor = [] + p_neighbor.append(target[x - 1][y - 1]) + p_neighbor.append(target[x][y - 1]) + p_neighbor.append(target[x + 1][y - 1]) + p_neighbor.append(target[x - 1][y]) + p_neighbor.append(target[x + 1][y]) + p_neighbor.append(target[x - 1][y + 1]) + p_neighbor.append(target[x][y + 1]) + p_neighbor.append(target[x + 1][y + 1]) + for i in range(8): + if p_neighbor[i] != 0: + connectivity += 1 + if connectivity >= 3: + is_special += 1 + elif connectivity == 1: + is_endpoint 
+= 1 + p_neighbor.sort() + for j in range(len(p_neighbor) - 1): + if p_neighbor[j + 1] != p_neighbor[j]: + l_neighbor.append(p_neighbor[j + 1]) + return is_special, is_endpoint, l_neighbor + + def breakup(self, num_labels, stats, labels): + """break up""" + excption = 0 + for area in range(num_labels - 1): + current_label_x = [] + current_label_y = [] + for area_x in range(stats[area + 1][2]): + for area_y in range(stats[area + 1][3]): + if labels[stats[area + 1][1] + area_y][ + stats[area + 1][0] + area_x + ] == (area + 1): + current_label_x.append(stats[area + 1][1] + area_y) + current_label_y.append(stats[area + 1][0] + area_x) + special = 0 + special_points = [] + endpoint = 0 + for p in range(len(current_label_x)): + is_special, is_endpoint, _ = self.calconnectivity( + labels, current_label_x[p], current_label_y[p] + ) + if is_special == 1: + special += 1 + special_points.append(p) + if is_endpoint == 1: + endpoint += 1 + + if special != 0: + excption = special + + for p in special_points: + labels[current_label_x[p]][current_label_y[p]] = 0 + + return num_labels, stats, labels, excption + + def cal_distance(self, x1, y1, x2, y2): + return (x2 - x1) * (x2 - x1) + (y2 - y1) * (y2 - y1) + + def sort_points(self, endlat, endlon, lat, lon): + """sort_points""" + sorted_lat = [] + sorted_lon = [] + sorted_lat.append(endlat) + sorted_lon.append(endlon) + ori_len = len(lon) + while len(sorted_lon) != ori_len: + minpos = 1 + min_dist = self.cal_distance(lat[0], lon[0], lat[1], lon[1]) + for i in range(len(lat) - 1): + if self.cal_distance(lat[0], lon[0], lat[i + 1], lon[i + 1]) < min_dist: + min_dist = self.cal_distance(lat[0], lon[0], lat[i + 1], lon[i + 1]) + minpos = i + 1 + sorted_lat.append(lat[minpos]) + sorted_lon.append(lon[minpos]) + del lat[0] + del lon[0] + lat.insert(0, lat[minpos - 1]) + lon.insert(0, lon[minpos - 1]) + del lat[minpos] + del lon[minpos] + return sorted_lat, sorted_lon + + def visualization( + self, num_labels, stats, labels, file_lkf, fin_lkf_width_total, label + ): + """visualization""" + m, lon_ori, lat_ori = self._initialize_map_and_coordinates() + + all_save_lons, all_save_lats, all_save_width = self._process_all_labels( + num_labels, stats, labels, fin_lkf_width_total, lon_ori, lat_ori, m + ) + + self._save_results_and_visualize( + all_save_lons, all_save_lats, all_save_width, file_lkf, label, m + ) + + def _initialize_map_and_coordinates(self): + """initialize map_and coordinates""" + m = Basemap(projection="npstere", boundinglat=70, lon_0=0, resolution="l") + + lon_path = os.path.join(self.config["data"].get("data_path"), "LONC.bin") + lat_path = os.path.join(self.config["data"].get("data_path"), "LATC.bin") + data_shape = self.config["data"].get("data_shape") + lonc = readbin(lon_path, data_shape) + lon_ori = lonc[1000:3000, 1:2001] + latc = readbin(lat_path, data_shape) + lat_ori = latc[1000:3000, 1:2001] + + return m, lon_ori, lat_ori + + def _process_all_labels(self, num_labels, stats, labels, fin_lkf_width_total, lon_ori, lat_ori, m): + """process all labels""" + all_save_lons = [] + all_save_lats = [] + all_save_width = [] + + for i in range(num_labels - 1): + if stats[i + 1][4] <= 10: + continue + + lat, lon = self._process_single_label(i, stats, labels, lon_ori, lat_ori) + + if lat and lon: + all_save_lats.append(np.array(lat)) + all_save_lons.append(np.array(lon)) + all_save_width.append(fin_lkf_width_total[i]) + + xx, yy = m(np.array(lon), np.array(lat)) + plt.plot(xx, yy, ".", ms=1) + + return all_save_lons, all_save_lats, all_save_width + + def 
_process_single_label(self, label_index, stats, labels, lon_ori, lat_ori): + """process single label""" + lat = [] + lon = [] + endlat = 0 + endlon = 0 + + current_stat = stats[label_index + 1] + label_id = label_index + 1 + + for j in range(current_stat[2]): + for k in range(current_stat[3]): + row_idx = current_stat[1] + k + col_idx = current_stat[0] + j + + if labels[row_idx][col_idx] == label_id: + lat_point, lon_point, is_endpoint = self._process_label_point( + labels, row_idx, col_idx, lon_ori, lat_ori + ) + + if is_endpoint: + endlat = lat_point + endlon = lon_point + lat.insert(0, lat_point) + lon.insert(0, lon_point) + else: + lat.append(lat_point) + lon.append(lon_point) + + if lat and lon: + lat, lon = self.sort_points(endlat, endlon, lat, lon) + + return lat, lon + + def _process_label_point(self, labels, row_idx, col_idx, lon_ori, lat_ori): + """process label point""" + _, fis_endpoint, _ = self.calconnectivity(labels, row_idx, col_idx) + lon_point, lat_point = self.to_lonlat_bin(lon_ori, lat_ori, row_idx, col_idx) + is_endpoint = (fis_endpoint == 1) + + return lat_point, lon_point, is_endpoint + + def _save_results_and_visualize(self, all_save_lons, all_save_lats, all_save_width, + file_lkf, label, m): + """save results and visualize""" + vis = os.path.join(self.config["test"].get("output_path"), "vis") + os.makedirs(vis, exist_ok=True) + + if label == "mask": + self._save_mask_results(all_save_lons, all_save_lats, all_save_width, file_lkf) + vis_label = vis + "/mask_" + else: + self._save_prediction_results(all_save_lons, all_save_lats, all_save_width, file_lkf) + vis_label = vis + "/pred_" + + m.drawmapboundary() + m.drawcoastlines() + m.fillcontinents() + + plt.title("LKF detected") + plt.savefig(vis_label + str(file_lkf) + ".png", dpi=1200) + plt.clf() + + def _save_mask_results(self, all_save_lons, all_save_lats, all_save_width, file_lkf): + """save mask results""" + self.detect_result_label = os.path.join( + self.config["test"].get("output_path"), "detect_result_label" + ) + os.makedirs(self.detect_result_label, exist_ok=True) + + self.detect_result_label_width = os.path.join( + self.config["test"].get("output_path"), "detect_result_label_width" + ) + os.makedirs(self.detect_result_label_width, exist_ok=True) + + np.save( + self.detect_result_label + "/detect_result_" + str(file_lkf) + ".npy", + np.asarray([all_save_lons, all_save_lats], dtype=object), + ) + np.save( + self.detect_result_label_width + "/detect_result_width_" + str(file_lkf) + ".npy", + np.asarray([all_save_width], dtype=object), + ) + + def _save_prediction_results(self, all_save_lons, all_save_lats, all_save_width, file_lkf): + """save prediction results""" + self.detect_result = os.path.join( + self.config["test"].get("output_path"), "detect_result" + ) + os.makedirs(self.detect_result, exist_ok=True) + + self.detect_result_width = os.path.join( + self.config["test"].get("output_path"), "detect_result_width" + ) + os.makedirs(self.detect_result_width, exist_ok=True) + + np.save( + self.detect_result + "/detect_result_" + str(file_lkf) + ".npy", + np.asarray([all_save_lons, all_save_lats], dtype=object), + ) + np.save( + self.detect_result_width + "/detect_result_width_" + str(file_lkf) + ".npy", + np.asarray([all_save_width], dtype=object), + ) + + def evaluate_width_mask(self): + """evaluate_width_mask""" + kernel_size = 3 + padding = int((kernel_size - 1) / 2) + filename_root = self.output_dir_mask + for file in os.listdir(filename_root): + print(file) + plt.clf() + filename = 
os.path.join(filename_root, file) + gt = np.load(filename) + gt_pad_img = np.pad( + gt, ((2, 2), (2, 2)), "constant", constant_values=(np.nan, np.nan) + ) + gt_dect_result = np.zeros((gt_pad_img.shape[0], gt_pad_img.shape[1])) + for i in range(gt.shape[0]): + for j in range(gt.shape[1]): + gt_local_mat = gt_pad_img[ + i + padding - 6 : i + padding + 6, + j + padding - 6 : j + padding + 6, + ] + gt_local_sit = gt_local_mat.flatten() + gt_num_nan = len(gt_local_sit[np.isnan(gt_local_sit)]) + gt_local_sit[np.isnan(gt_local_sit)] = 0 + if len(gt_local_sit) > gt_num_nan: + gt_local_mean = np.sum(gt_local_sit) / ( + len(gt_local_sit) - gt_num_nan + ) + gt_local_std = np.std(gt_local_sit) + if gt_pad_img[i][j] < gt_local_mean - gt_local_std: + gt_dect_result[i][j] = 1 + gt_dect_result = gt_dect_result.astype(np.uint8) * 255 + _, _, gt_ori_stats, _ = ( + cv2.connectedComponentsWithStats(gt_dect_result) + ) + gt_dect_result = gt_dect_result / 255 + gt_skeleton0 = morphology.skeletonize(gt_dect_result) + gt_dect_result = gt_skeleton0.astype(np.uint8) * 255 + gt_num_labels, gt_labels, gt_stats, _ = ( + cv2.connectedComponentsWithStats(gt_dect_result) + ) + gt_lkf_width_total = [] + for ii in range(1, len(gt_ori_stats)): + gt_lkf_width = gt_ori_stats[ii][4] / gt_stats[ii][4] + gt_lkf_width_total.append(gt_lkf_width) + gt_exc = 1 + while gt_exc != 0: + gt_num_labels, gt_stats, gt_labels, gt_exc = self.breakup( + gt_num_labels, gt_stats, gt_labels + ) + print( + "========================================= break ==================================================" + ) + + self.visualization( + gt_num_labels, + gt_stats, + gt_labels, + file[:-4], + gt_lkf_width_total, + "mask", + ) + + def evaluate_width_pred(self): + """evaluate_width_pred""" + kernel_size = 3 + + padding = int((kernel_size - 1) / 2) + + filename_p_root = self.output_dir_pred + hccf_path = os.path.join(self.config["data"].get("data_path"), "hFacC.data") + landmask = readbin(hccf_path, self.config["data"].get("data_shape")) + landmask = landmask[1000:3000, 1:2001] + + for file in os.listdir(filename_p_root): + print(file) + plt.clf() + filename_p = os.path.join(filename_p_root, file) + img = np.load(filename_p) + img[landmask == 0] = 0 + img[img < 0] = 0 + pad_img = np.pad( + img, ((2, 2), (2, 2)), "constant", constant_values=(np.nan, np.nan) + ) + + dect_result = np.zeros((pad_img.shape[0], pad_img.shape[1])) + + for i in range(img.shape[0]): + for j in range(img.shape[1]): + local_mat = pad_img[ + i + padding - 6 : i + padding + 6, + j + padding - 6 : j + padding + 6, + ] + local_sit = local_mat.flatten() + num_nan = len(local_sit[np.isnan(local_sit)]) + local_sit[np.isnan(local_sit)] = 0 + if len(local_sit) > num_nan: + local_mean = np.sum(local_sit) / (len(local_sit) - num_nan) + local_std = np.std(local_sit) + if pad_img[i][j] < local_mean - local_std: + dect_result[i][j] = 1 + dect_result = dect_result.astype(np.uint8) * 255 + _, _, ori_stats, _ = ( + cv2.connectedComponentsWithStats(dect_result) + ) + dect_result = dect_result / 255 + + skeleton0 = morphology.skeletonize(dect_result) + dect_result = skeleton0.astype(np.uint8) * 255 + + num_labels, labels, stats, _ = cv2.connectedComponentsWithStats( + dect_result + ) + + lkf_width_total = [] + for ii in range(1, len(ori_stats)): + lkf_width = ori_stats[ii][4] / stats[ii][4] + lkf_width_total.append(lkf_width) + + exc = 1 + while exc != 0: + num_labels, stats, labels, exc = self.breakup(num_labels, stats, labels) + print( + "========================================= break 
==================================================" + ) + self.visualization( + num_labels, stats, labels, file[:-4], lkf_width_total, "pred" + ) + + def calc_dis_v3(self, lkf_fcst, lkf_sar): + """calc_dis""" + dis_cutoff = 50.0 + imax = len(lkf_fcst[0]) + jmax = len(lkf_sar[0]) + dismin = np.nan * np.zeros((imax, jmax)) + disnrst = np.empty([]) + dis_index = [] + for j in tqdm(np.arange(jmax)): + lon2, lat2 = lkf_sar[0][j], lkf_sar[1][j] + lkf_sar_len = haversine.haversine((lat2[0], lon2[0]), (lat2[-1], lon2[-1])) + if lkf_sar_len <= dis_cutoff: + dismin[:, j] = np.nan + continue + for i in np.arange(imax): + lon1, lat1 = lkf_fcst[0][i], lkf_fcst[1][i] + lkf_fcst_len = haversine.haversine( + (lat1[0], lon1[0]), (lat1[-1], lon1[-1]) + ) + if lkf_fcst_len <= dis_cutoff: + dismin[i, j] = np.nan + continue + dismin[i, j] = self._calculate_min_distance(lat1, lon1, lat2, lon2) + disnrst = np.nanmin(dismin, axis=0) + print(np.nanmax(disnrst), np.nanmin(disnrst), np.nanmean(disnrst)) + for index in range(jmax): + dis_index_now = np.where(dismin[:, index] == disnrst[index]) + dis_index_now = np.squeeze(dis_index_now) + if dis_index_now.size == 0: + dis_index.append(999) + elif dis_index_now.size > 1: + dis_index.append(np.int32(dis_index_now[0])) + else: + dis_index.append(np.int32(dis_index_now)) + dis_index = np.array(dis_index) + + return disnrst, dis_index + + def _calculate_min_distance(self, lat1, lon1, lat2, lon2): + """calculate_min_distance""" + dis = np.zeros((len(lon1), len(lon2))) + for ii in np.arange(len(lon1)): + for jj in np.arange(len(lon2)): + dis[ii, jj] = haversine.haversine( + (lat1[ii], lon1[ii]), (lat2[jj], lon2[jj]) + ) + dis1 = np.nanmin(dis, axis=1) + dis2 = np.nanmin(dis, axis=0) + dis1 = np.sort(dis1) + dis2 = np.sort(dis2) + n = min([np.sum(~np.isnan(dis1)), np.sum(~np.isnan(dis2))]) + return (np.sum(dis1[:n]) + np.sum(dis2[:n])) / (2 * n) + + def visual(self, pred, label, pair, lkf_name): + """visual""" + m = Basemap(projection="npstere", boundinglat=70, lon_0=0, resolution="l") + plt.clf() + for i in range(len(pair)): + if pair[i] == 999: + continue + + pred_lon = pred[0][pair[i]] + pred_lat = pred[1][pair[i]] + + label_lon = label[0][i] + label_lat = label[1][i] + + xx, yy = m(np.array(pred_lon), np.array(pred_lat)) + lxx, lyy = m(np.array(label_lon), np.array(label_lat)) + + plt.subplot(121) + plt.plot(xx, yy, ".", ms=1) + m.drawmapboundary() + m.drawcoastlines() + m.fillcontinents() + plt.title("pred") + plt.subplot(122) + plt.plot(lxx, lyy, ".", ms=1) + m.drawmapboundary() + m.drawcoastlines() + m.fillcontinents() + plt.title("label") + output_path = self.config["test"].get("output_path") + out_path = f"{output_path.rstrip('/')}/{lkf_name}.png" + plt.savefig(out_path, dpi=600) + + def evaluate_dis(self, pred, label, pair): + """evaluate_dis""" + avg_diff = 0 + for i in range(len(pair)): + if pair[i] == 999: + continue + pred_lon = pred[0][pair[i]] + pred_lat = pred[1][pair[i]] + + label_lon = label[0][i] + label_lat = label[1][i] + + lkf_pred_len = haversine.haversine( + (pred_lat[0], pred_lon[0]), (pred_lat[-1], pred_lon[-1]) + ) + lkf_label_len = haversine.haversine( + (label_lat[0], label_lon[0]), (label_lat[-1], label_lon[-1]) + ) + + diff = abs(lkf_label_len - lkf_pred_len) / lkf_label_len + avg_diff = avg_diff + diff + + avg_diff = avg_diff / len(pair) + + return avg_diff + + def get_degree(self, lata, lona, latb, lonb): + """ + Args: + point p1(latA, lonA) + point p2(latB, lonB) + Returns: + bearing between the two GPS points, + default: the basis of 
heading direction is north + """ + radlata = radians(lata) + radlona = radians(lona) + radlatb = radians(latb) + radlonb = radians(lonb) + dlon = radlonb - radlona + y = sin(dlon) * cos(radlatb) + x = cos(radlata) * sin(radlatb) - sin(radlata) * cos(radlatb) * cos(dlon) + brng = degrees(atan2(y, x)) + brng = (brng + 360) % 360 + return brng + + def evaluate_degree(self, pred, label, pair): + """evaluate_degree""" + avg_diff = 0 + for i in range(len(pair)): + if pair[i] == 999: + continue + pred_lon = pred[0][pair[i]] + pred_lat = pred[1][pair[i]] + + label_lon = label[0][i] + label_lat = label[1][i] + + lkf_pred_degree = self.get_degree( + pred_lat[0], pred_lon[0], pred_lat[2], pred_lon[2] + ) + lkf_label_degree = self.get_degree( + label_lat[0], label_lon[0], label_lat[2], label_lon[2] + ) + + diff = abs(lkf_pred_degree - lkf_label_degree) + if diff >= 180: + diff = 360 - diff + avg_diff = avg_diff + diff + + avg_diff = avg_diff / len(pair) + + return avg_diff + + def evaluate_width(self, pred_width, label_width, pair): + """evaluate_width""" + avg_diff = 0 + for i in range(len(pair)): + if pair[i] == 999: + continue + + lkf_pred_width = pred_width[0][pair[i]] + lkf_label_width = label_width[0][i] + + diff = abs(lkf_label_width - lkf_pred_width) / lkf_label_width + avg_diff = avg_diff + diff + + avg_diff = avg_diff / len(pair) + + return avg_diff + + def evaluate_lead_all(self): + """evaluate_lead_all""" + fcst_root_path = self.detect_result + model_root_path = self.detect_result_label + fcst_width_path = self.detect_result_width + model_width_path = self.detect_result_label_width + lonc_path = os.path.join(self.config["data"].get("data_path"), "LONC.bin") + lonc = readbin(lonc_path, self.config["data"].get("data_shape")) + lonc = lonc[1000:3000, 1:2001] + avg_dis_diff = 0 + avg_degree_diff = 0 + avg_width_diff = 0 + max_dis_diff = 0 + max_degree_diff = 0 + max_width_diff = 0 + for lkf_fcst_file in os.listdir(fcst_root_path): + lkf_fcst_path = os.path.join(fcst_root_path, lkf_fcst_file) + lkf_model_path = os.path.join(model_root_path, lkf_fcst_file) + lkf_fcst_width_file = ( + "detect_result_width_ice_input_" + lkf_fcst_file.split("_")[4] + ) + lkf_fcst_width_path = os.path.join(fcst_width_path, lkf_fcst_width_file) + lkf_model_width_path = os.path.join(model_width_path, lkf_fcst_width_file) + lkf_fcst = np.load(lkf_fcst_path, allow_pickle=True) + lkf_model = np.load(lkf_model_path, allow_pickle=True) + lkf_fcst_width = np.load(lkf_fcst_width_path, allow_pickle=True) + lkf_model_width = np.load(lkf_model_width_path, allow_pickle=True) + _, pair = self.calc_dis_v3(lkf_fcst, lkf_model) + self.visual(lkf_fcst, lkf_model, pair, lkf_fcst_file[:-4]) + dis_width = self.evaluate_width(lkf_fcst_width, lkf_model_width, pair) + print(lkf_fcst_file, "dis width is: ", dis_width) + if dis_width > max_width_diff: + max_width_diff = dis_width + avg_width_diff = avg_width_diff + dis_width + dis_diff = self.evaluate_dis(lkf_fcst, lkf_model, pair) + print(lkf_fcst_file, "dis diff is: ", dis_diff) + if dis_diff > max_dis_diff: + max_dis_diff = dis_diff + avg_dis_diff = avg_dis_diff + dis_diff + degree_diff = self.evaluate_degree(lkf_fcst, lkf_model, pair) + if degree_diff > max_degree_diff: + max_degree_diff = degree_diff + print(lkf_fcst_file, "degree diff is: ", degree_diff) + avg_degree_diff = avg_degree_diff + degree_diff + + print("avg diff width: ", avg_width_diff / len(os.listdir(fcst_root_path))) + print("avg diff dis: ", avg_dis_diff / len(os.listdir(fcst_root_path))) + print("avg diff degree: ", 
avg_degree_diff / len(os.listdir(fcst_root_path))) + + print("max diff width: ", max_width_diff) + print("max diff dis: ", max_dis_diff) + print("max diff degree: ", max_degree_diff) + + def lonlat2xy2km(self, l_lon, lon): + x = int(np.argwhere(l_lon == lon)[0][0]) + y = int(np.argwhere(l_lon == lon)[0][1]) + return x, y + + def evaluate_acc(self): + """evaluate_acc""" + fcst_root_path = self.detect_result + model_root_path = self.detect_result_label + lon_path = os.path.join(self.config["data"].get("data_path"), "LONC.bin") + lon = readbin(lon_path, self.config["data"].get("data_shape")) + lon = lon[1000:3000, 1:2001] + + avg_acc = 0 + + for lkf_fcst_file in os.listdir(fcst_root_path): + lkf_fcst_path = os.path.join(fcst_root_path, lkf_fcst_file) + lkf_model_path = os.path.join(model_root_path, lkf_fcst_file) + lkf_fcst = np.load(lkf_fcst_path, allow_pickle=True) + lkf_model = np.load(lkf_model_path, allow_pickle=True) + + pred = np.zeros((2000, 2000)) + label = np.zeros((2000, 2000)) + for i in range(lkf_fcst.shape[1]): + for j in range(len(lkf_fcst[0][i])): + if lkf_fcst[0][i][j] == 0: + continue + px, py = self.lonlat2xy2km(lon, lkf_fcst[0][i][j]) + pred[px][py] = 255 + + for ii in range(lkf_model.shape[1]): + for jj in range(len(lkf_model[0][ii])): + if lkf_model[0][ii][jj] == 0: + continue + lx, ly = self.lonlat2xy2km(lon, lkf_model[0][ii][jj]) + label[lx][ly] = 255 + + acc_num = 0 + for w in range(label.shape[0]): + for h in range(label.shape[1]): + if pred[w][h] == label[w][h]: + acc_num = acc_num + 1 + acc = acc_num / (label.shape[0] * label.shape[1]) + print("acc for ", lkf_fcst_file, " is:", acc) + avg_acc = avg_acc + acc + + avg_acc = avg_acc / (len(os.listdir(fcst_root_path))) + print("avg acc is: ", avg_acc) diff --git a/MindEarth/applications/sea/LeadFormer/src/loss.py b/MindEarth/applications/sea/LeadFormer/src/loss.py new file mode 100644 index 0000000000000000000000000000000000000000..40fca1736fb12c70093a8b4af57d189e39fb8288 --- /dev/null +++ b/MindEarth/applications/sea/LeadFormer/src/loss.py @@ -0,0 +1,263 @@ +# right 2020-2022 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""loss""" +import numpy as np + +import mindspore +import mindspore.nn as nn +import mindspore.common.dtype as mstype +import mindspore.ops.operations as F +from mindspore.ops import functional as F2 +from mindspore.nn.cell import Cell +from mindspore import ops, Tensor + +from .utils import warp, make_grid + + +class MyLoss(Cell): + """ + Base class for other losses. 
+ """ + def __init__(self, reduction="mean"): + super().__init__() + if reduction is None: + reduction = "none" + + if reduction not in ("mean", "sum", "none"): + raise ValueError( + f"reduction method for {reduction.lower()} is not supported" + ) + + self.average = True + self.reduce = True + if reduction == "sum": + self.average = False + if reduction == "none": + self.reduce = False + + self.reduce_mean = F.ReduceMean() + self.reduce_sum = F.ReduceSum() + self.mul = F.Mul() + self.cast = F.Cast() + + def get_axis(self, x): + shape = F2.shape(x) + length = F2.tuple_len(shape) + perm = F2.make_range(0, length) + return perm + + def get_loss(self, x, weights=1.0): + """ + Computes the weighted loss + Args: + weights: Optional `Tensor` whose rank is either 0, or the same rank as inputs, and must be broadcastable to + inputs (i.e., all dimensions must be either `1`, or the same as the corresponding inputs dimension). + """ + input_dtype = x.dtype + x = self.cast(x, mstype.float32) + weights = self.cast(weights, mstype.float32) + x = self.mul(weights, x) + if self.reduce: + axis = self.get_axis(x) + if self.average: + x = self.reduce_mean(x, axis) + else: + x = self.reduce_sum(x, axis) + x = self.cast(x, input_dtype) + return x + + def construct(self, logits, label): + """unused""" + raise NotImplementedError + + +class CrossEntropyWithLogits(MyLoss): + """CrossEntropyWithLogits""" + def __init__(self): + super().__init__() + self.transpose_fn = F.Transpose() + self.reshape_fn = F.Reshape() + self.softmax_cross_entropy_loss = nn.SoftmaxCrossEntropyWithLogits() + self.cast = F.Cast() + + def construct(self, logits, label): + """construct""" + logits = self.transpose_fn(logits, (0, 2, 3, 1)) + logits = self.cast(logits, mindspore.float32) + label = self.transpose_fn(label, (0, 2, 3, 1)) + _, _, _, c = F.Shape()(label) + + loss = self.reduce_mean( + self.softmax_cross_entropy_loss( + self.reshape_fn(logits, (-1, c)), self.reshape_fn(label, (-1, c)) + ) + ) + return self.get_loss(loss) + + +class MultiCrossEntropyWithLogits(nn.Cell): + """MultiCrossEntropyWithLogits""" + def __init__(self): + super().__init__() + self.loss = CrossEntropyWithLogits() + self.squeeze = F.Squeeze(axis=0) + + def construct(self, logits, label): + """construct""" + total_loss = 0 + for i in range(len(logits)): + total_loss += self.loss(self.squeeze(logits[i : i + 1]), label) + return total_loss + + +class MSELoss(MyLoss): + """MSEloss""" + def __init__(self): + super().__init__() + self.transpose_fn = F.Transpose() + self.reshape_fn = F.Reshape() + self.mse_loss = nn.MSELoss() + self.dice_loss = nn.DiceLoss(smooth=1e-5) + self.cast = F.Cast() + self.mae_loss = nn.MAELoss(reduction="mean") + self.rmse_loss = nn.RMSELoss() + + def construct(self, logits, label): + """construct""" + print(logits.shape, label.shape) + logits = self.transpose_fn(logits, (0, 2, 3, 1)) + logits = self.cast(logits, mindspore.float32) + label = self.transpose_fn(label, (0, 2, 3, 1)) + label = label[:, :, :, 1] + label = label.unsqueeze(dim=3) + print(logits.shape, label.shape) + _, _, _, c = F.Shape()(label) + + rmse_loss = self.rmse_loss( + self.reshape_fn(logits, (-1, c)), self.reshape_fn(label, (-1, c)) + ) + return self.get_loss(rmse_loss) + + +class MultiMSELoss(nn.Cell): + """MultiMSELoss""" + def __init__(self): + super().__init__() + self.loss = MSELoss() + self.squeeze = F.Squeeze(axis=0) + + def construct(self, logits, label): + print(logits.shape, label.shape) + total_loss = 0 + for i in range(len(logits)): + total_loss += 
self.loss(self.squeeze(logits[i : i + 1]), label) + return total_loss + + +class WeightDistance(nn.Cell): + """Weighted L1 distance""" + + def construct(self, true_frame, pred_frame): + loss = ops.mean(ops.abs(true_frame - pred_frame)) + return loss + + +class MotionLossNet(nn.Cell): + """Motion regularization""" + def __init__(self, in_channels=3, out_channels=3, kernel_size=3): + super().__init__() + kernel_v = np.array([[1, 0, -1], [2, 0, -2], [1, 0, -1]]).astype(np.float32) + kernel_h = np.array([[1, 2, 1], [0, 0, 0], [-1, -2, -1]]).astype(np.float32) + self.weight_vt1 = self.get_kernels(kernel_v, out_channels) + self.weight_vt2 = self.get_kernels(kernel_h, out_channels) + self.conv_2d_v = nn.Conv2d( + in_channels, out_channels, kernel_size=kernel_size, has_bias=False + ) + self.conv_2d_v.weight.set_data(Tensor(self.weight_vt1, mindspore.float32)) + for w in self.conv_2d_v.trainable_params(): + w.requires_grad = False + self.conv_2d_h = nn.Conv2d( + in_channels, out_channels, kernel_size=kernel_size, has_bias=False + ) + self.conv_2d_h.weight.set_data(Tensor(self.weight_vt2, mindspore.float32)) + for w in self.conv_2d_h.trainable_params(): + w.requires_grad = False + + @staticmethod + def get_kernels(kernel, repeats): + kernel = np.expand_dims(kernel, axis=(0, 1)) + kernels = [kernel] * repeats + kernels = np.concatenate(kernels, axis=0) + return kernels + + def calc_diff_v(self, image): + diff_v1 = self.conv_2d_v(image) + diff_v2 = self.conv_2d_h(image) + lambda_v = diff_v1**2 + diff_v2**2 + loss = ops.sum(lambda_v) + return loss + + def custom_2d_conv_sobel(self, image): + motion_loss1 = self.calc_diff_v(image) + motion_loss2 = self.calc_diff_v(image) + loss = (motion_loss1 + motion_loss2) / ( + image.shape[0] * image.shape[-1] * image.shape[-2] + ) + return loss + + def construct(self, motion): + loss1 = self.custom_2d_conv_sobel(motion[:, :1]) + loss2 = self.custom_2d_conv_sobel(motion[:, 1:]) + return loss1 + loss2 + + +class EvolutionLoss(nn.Cell): + """Evolution loss definition""" + def __init__(self, model): + super().__init__() + self.model = model + self.loss_fn_accum = WeightDistance() + self.loss_fn_motion = MotionLossNet( + in_channels=1, out_channels=1, kernel_size=3 + ) + sample_tensor = np.zeros((1, 1, 2000, 2000)).astype(np.float32) + self.grid = Tensor(make_grid(sample_tensor), mindspore.float32) + self.lamb = float(1e-2) + + def construct(self, logits, labels): + """construct""" + intensity = logits[0] + motion = logits[1] + batch, _, height, width = logits[0].shape + motion_ = motion.reshape(batch, 1, 2, height, width) + intensity_ = intensity.reshape(batch, 1, 1, height, width) + accum = 0 + last_frame = labels[:, 0, :, :] + last_frame = ops.unsqueeze(last_frame, dim=1) + grid = self.grid.tile((batch, 1, 1, 1)) + next_frame = labels[:, 1, :, :] + next_frame = ops.unsqueeze(next_frame, dim=1) + xt_1 = warp( + last_frame, motion_[:, 0], grid, mode="bilinear", padding_mode="border" + ) + accum += self.loss_fn_accum(next_frame, xt_1) + last_frame = warp( + last_frame, motion_[:, 0], grid, mode="nearest", padding_mode="border" + ) + last_frame = last_frame + intensity_[:, 0] + accum += self.loss_fn_accum(next_frame, last_frame) + motion = self.loss_fn_motion(motion_[:, 0]) + loss = accum + self.lamb * motion + return loss diff --git a/MindEarth/applications/sea/LeadFormer/src/myop.py b/MindEarth/applications/sea/LeadFormer/src/myop.py new file mode 100644 index 0000000000000000000000000000000000000000..1bcaff44f9f02d022d926f80f75170854944608a --- /dev/null +++ 
b/MindEarth/applications/sea/LeadFormer/src/myop.py @@ -0,0 +1,168 @@ +# right 2020-2022 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""adam""" +from __future__ import absolute_import, division +import numpy as np + +from mindspore import context +from mindspore.common import dtype as mstype +from mindspore.common.api import jit +from mindspore.ops import operations as P +from mindspore.ops import composite as C +from mindspore.ops import functional as F +from mindspore.common.tensor import Tensor +from mindspore import _checkparam as validator +from mindspore.nn.optim.optimizer import Optimizer + + +_adam_opt = C.MultitypeFuncGraph("adam_opt") +_fused_adam_weight_decay = C.MultitypeFuncGraph("fused_adam_weight_decay") +_lazy_adam_opt = C.MultitypeFuncGraph("lazy_adam_opt") +_scaler_one = Tensor(1, mstype.int32) +_scaler_ten = Tensor(10, mstype.float32) + + +@_adam_opt.register("Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", + "Tensor", "Bool", "Bool") +def _update_run_op(beta1, beta2, eps, lr, weight_decay, param, m, v, gradient, decay_flag, optim_filter): + """Update parameters.""" + op_cast = P.Cast() + if optim_filter: + op_mul = P.Mul() + op_square = P.Square() + op_sqrt = P.Sqrt() + op_cast = P.Cast() + op_reshape = P.Reshape() + op_shape = P.Shape() + param_fp32 = op_cast(param, mstype.float32) + m_fp32 = op_cast(m, mstype.float32) + v_fp32 = op_cast(v, mstype.float32) + gradient_fp32 = op_cast(gradient, mstype.float32) + + next_m = op_mul(beta1, m_fp32) + op_mul(op_cast(F.tuple_to_array((1.0,)), mstype.float32) + - beta1, gradient_fp32) + + next_v = op_mul(beta2, v_fp32) + op_mul(op_cast(F.tuple_to_array((1.0,)), mstype.float32) + - beta2, op_square(gradient_fp32)) + + update = next_m / (eps + op_sqrt(next_v)) + if decay_flag: + update = op_mul(weight_decay, param_fp32) + update + + update_with_lr = op_mul(lr, update) + next_param = param_fp32 - op_reshape(update_with_lr, op_shape(param_fp32)) + + next_param = F.depend(next_param, F.assign(param, op_cast(next_param, F.dtype(param)))) + next_param = F.depend(next_param, F.assign(m, op_cast(next_m, F.dtype(m)))) + next_param = F.depend(next_param, F.assign(v, op_cast(next_v, F.dtype(v)))) + + return op_cast(next_param, F.dtype(param)) + return op_cast(gradient, F.dtype(param)) + +def _check_param_value(beta1, beta2, eps, prim_name): + """Check the type of inputs.""" + validator.check_value_type("beta1", beta1, [float], prim_name) + validator.check_value_type("beta2", beta2, [float], prim_name) + validator.check_value_type("eps", eps, [float], prim_name) + validator.check_float_range(beta1, 0.0, 1.0, validator.INC_NEITHER, "beta1", prim_name) + validator.check_float_range(beta2, 0.0, 1.0, validator.INC_NEITHER, "beta2", prim_name) + validator.check_positive_float(eps, "eps", prim_name) + + +class AdamWeightDecay(Optimizer): + """ + Adam optimizer with weight decay. 
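# A plain-NumPy sketch of the single-parameter update implemented by
# _update_run_op above (decoupled weight decay added to the raw Adam step,
# no bias-correction terms), useful as a reference when checking the optimizer.
import numpy as np

def adamw_reference_step(param, m, v, grad, lr=1e-3, beta1=0.9, beta2=0.999,
                         eps=1e-6, weight_decay=0.0, decay_flag=True):
    """One AdamWeightDecay step on NumPy arrays; mirrors _update_run_op."""
    m = beta1 * m + (1.0 - beta1) * grad
    v = beta2 * v + (1.0 - beta2) * grad ** 2
    update = m / (eps + np.sqrt(v))
    if decay_flag:
        update = weight_decay * param + update
    return param - lr * update, m, v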
+ + This implementation adds weight decay to the standard Adam algorithm. + It supports parallel optimization and can run on different device targets. + Args: + params: Network parameters to optimize + learning_rate: Learning rate (default: 1e-3) + beta1: Exponential decay rate for first moment estimates (default: 0.9) + beta2: Exponential decay rate for second moment estimates (default: 0.999) + eps: Term added to denominator for numerical stability (default: 1e-6) + weight_decay: Weight decay coefficient (default: 0.0) + """ + _support_parallel_optimizer = True + + def __init__(self, params, learning_rate=1e-3, beta1=0.9, beta2=0.999, eps=1e-6, weight_decay=0.0): + super().__init__(learning_rate, params, weight_decay) + _check_param_value(beta1, beta2, eps, self.cls_name) + self.beta1 = Tensor(np.array([beta1]).astype(np.float32)) + self.beta2 = Tensor(np.array([beta2]).astype(np.float32)) + self.eps = Tensor(np.array([eps]).astype(np.float32)) + self.moments1 = self._parameters.clone(prefix="adam_m", init='zeros') + self.moments2 = self._parameters.clone(prefix="adam_v", init='zeros') + self.fused_opt = P.AdamWeightDecay() + if context.get_context("device_target") == "Ascend": + self.use_fused_opt = False + else: + self.use_fused_opt = True + + @jit + def construct(self, gradients): + """construct""" + gradients = self.flatten_gradients(gradients) + weight_decay = self.get_weight_decay() + lr = self.get_lr() + if self.use_fused_opt: + if self.is_group: + if self.is_group_lr: + optim_result = self.hyper_map( + F.partial(_fused_adam_weight_decay, self.fused_opt, self.beta1, self.beta2, self.eps), + lr, weight_decay, self._parameters, self.moments1, + self.moments2, gradients, self.decay_flags, self.optim_filter) + else: + optim_result = self.hyper_map( + F.partial(_fused_adam_weight_decay, self.fused_opt, self.beta1, self.beta2, self.eps, lr), + weight_decay, self._parameters, self.moments1, self.moments2, + gradients, self.decay_flags, self.optim_filter) + else: + optim_result = self.hyper_map( + F.partial(_fused_adam_weight_decay, self.fused_opt, self.beta1, self.beta2, self.eps, lr, + weight_decay), + self._parameters, self.moments1, self.moments2, + gradients, self.decay_flags, self.optim_filter) + else: + if self.is_group: + if self.is_group_lr: + optim_result = self.hyper_map(F.partial(_adam_opt, self.beta1, self.beta2, self.eps), + lr, weight_decay, self._parameters, self.moments1, + self.moments2, gradients, self.decay_flags, self.optim_filter) + else: + optim_result = self.hyper_map(F.partial(_adam_opt, self.beta1, self.beta2, self.eps, lr), + weight_decay, self._parameters, self.moments1, self.moments2, + gradients, self.decay_flags, self.optim_filter) + else: + optim_result = self.hyper_map(F.partial(_adam_opt, self.beta1, self.beta2, self.eps, lr, weight_decay), + self._parameters, self.moments1, self.moments2, + gradients, self.decay_flags, self.optim_filter) + if self.use_parallel: + self.broadcast_params(optim_result) + + return optim_result + + @Optimizer.target.setter + def target(self, value): + """ + If the input value is set to "CPU", the parameters will be updated on the host using the Fused + optimizer operation. 
+ """ + self._set_base_target(value) + if value == 'CPU': + self.fused_opt.set_device("CPU") + self.use_fused_opt = True + else: + self.use_fused_opt = False diff --git a/MindEarth/applications/sea/LeadFormer/src/segformer.py b/MindEarth/applications/sea/LeadFormer/src/segformer.py new file mode 100644 index 0000000000000000000000000000000000000000..f677c044e06e8a35bef560d57155b8b715999576 --- /dev/null +++ b/MindEarth/applications/sea/LeadFormer/src/segformer.py @@ -0,0 +1,199 @@ +# right 2020-2022 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""net""" +from functools import partial + +import mindspore.nn as nn +import mindspore.ops as ops + +from .backbone import MixVisionTransformer, init_weights + + +class MLP(nn.Cell): + """ + Linear Embedding + """ + def __init__(self, input_dim=2048, embed_dim=768): + super().__init__() + self.proj = nn.Dense(input_dim, embed_dim) + self.apply(init_weights) + + def construct(self, x): + x = x.flatten(start_dim=2).swapaxes(1, 2) + x = self.proj(x) + return x + + +class ConvModule(nn.Cell): + """ + A convolutional module including convolution, batch normalization, and activation. + + Args: + c1 (int): Number of input channels + c2 (int): Number of output channels + k (int): Kernel size, default=1 + s (int): Stride, default=1 + p (int): Padding size, default=0 + g (int): Number of groups for grouped convolution, default=1 + act (Union[bool, nn.Cell]): Activation function. 
True for ReLU, False/None for Identity, + or directly provide a activation module + """ + def __init__(self, c1, c2, k=1, s=1, p=0, g=1, act=True): + super().__init__() + self.conv = nn.Conv2d( + c1, + c2, + kernel_size=k, + stride=s, + pad_mode="pad", + padding=p, + has_bias=False, + group=g, + ) + self.bn = nn.BatchNorm2d(c2, eps=0.001, momentum=0.97) + self.act = ( + nn.ReLU() + if act is True + else (act if isinstance(act, nn.Cell) else nn.Identity()) + ) + + def construct(self, x): + """Standard forward pass: Conv -> BN -> Activation""" + return self.act(self.bn(self.conv(x))) + + def fuseforward(self, x): + return self.act(self.conv(x)) + + +class SegFormerHead(nn.Cell): + """ + SegFormer: Simple and Efficient Design for Semantic Segmentation with Transformers + """ + def __init__( + self, + num_classes=20, + in_channels=None, + embedding_dim=768, + dropout_ratio=0.1, + ): + super().__init__() + in_channels = in_channels if in_channels is not None else [64, 128, 320, 512] + c1_in_channels, c2_in_channels, c3_in_channels, c4_in_channels = in_channels + + self.linear_c4 = MLP(input_dim=c4_in_channels, embed_dim=embedding_dim) + self.linear_c3 = MLP(input_dim=c3_in_channels, embed_dim=embedding_dim) + self.linear_c2 = MLP(input_dim=c2_in_channels, embed_dim=embedding_dim) + self.linear_c1 = MLP(input_dim=c1_in_channels, embed_dim=embedding_dim) + + self.linear_fuse = ConvModule( + c1=embedding_dim * 4, + c2=embedding_dim, + k=1, + ) + + self.linear_pred = nn.Conv2d( + embedding_dim, num_classes, kernel_size=1, has_bias=True + ) + self.dropout = nn.Dropout2d(dropout_ratio) + self.linear_v4 = MLP(input_dim=c4_in_channels, embed_dim=embedding_dim) + self.linear_v3 = MLP(input_dim=c3_in_channels, embed_dim=embedding_dim) + self.linear_v2 = MLP(input_dim=c2_in_channels, embed_dim=embedding_dim) + self.linear_v1 = MLP(input_dim=c1_in_channels, embed_dim=embedding_dim) + + self.linear_fuse_v = ConvModule( + c1=embedding_dim * 4, + c2=embedding_dim, + k=1, + ) + + self.linear_pred_v = nn.Conv2d(embedding_dim, 2, kernel_size=1, has_bias=True) + self.dropout_v = nn.Dropout2d(dropout_ratio) + + def _upsample(self, inputs, linears, fuse, pred, dropout): + """upsample""" + c1, c2, c3, c4 = inputs + b = c4.shape[0] + proj4 = linears[3](c4).permute(0, 2, 1).reshape(b, -1, c4.shape[2], c4.shape[3]) + proj3 = linears[2](c3).permute(0, 2, 1).reshape(b, -1, c3.shape[2], c3.shape[3]) + proj2 = linears[1](c2).permute(0, 2, 1).reshape(b, -1, c2.shape[2], c2.shape[3]) + proj1 = linears[0](c1).permute(0, 2, 1).reshape(b, -1, c1.shape[2], c1.shape[3]) + + up4 = ops.interpolate(proj4, size=c1.shape[2:], mode="bilinear", align_corners=False) + up3 = ops.interpolate(proj3, size=c1.shape[2:], mode="bilinear", align_corners=False) + up2 = ops.interpolate(proj2, size=c1.shape[2:], mode="bilinear", align_corners=False) + + fused = fuse(ops.cat((up4, up3, up2, proj1), axis=1)) + out = pred(dropout(fused)) + return out + + def construct(self, inputs): + """construct""" + context_output = self._upsample( + inputs, + [self.linear_c1, self.linear_c2, self.linear_c3, self.linear_c4], + self.linear_fuse, + self.linear_pred, + self.dropout, + ) + vision_output = self._upsample( + inputs, + [self.linear_v1, self.linear_v2, self.linear_v3, self.linear_v4], + self.linear_fuse_v, + self.linear_pred_v, + self.dropout_v, + ) + return context_output, vision_output + + +class SegFormer(nn.Cell): + """Simple and Efficient Semantic Segmentation Framework""" + def __init__( + self, + in_channels=None, + num_classes=21, + 
embedding_dim=256, + num_heads=None, + mlp_ratios=None, + depths=None, + sr_ratios=None, + ): + self.in_channels = in_channels if in_channels is not None else [64, 128, 320, 512] + self.num_heads = num_heads if num_heads is not None else [1, 2, 5, 8] + self.mlp_ratios = mlp_ratios if mlp_ratios is not None else [4, 4, 4, 4] + self.depths = depths if depths is not None else [2, 2, 2, 2] + self.sr_ratios = sr_ratios if sr_ratios is not None else [8, 4, 2, 1] + super().__init__() + self.backbone = MixVisionTransformer( + embed_dims=self.in_channels, + num_heads=self.num_heads, + mlp_ratios=self.mlp_ratios, + qkv_bias=True, + norm_layer=partial(nn.LayerNorm, epsilon=1e-6), + depths=self.depths, + sr_ratios=self.sr_ratios, + drop_rate=0.0, + drop_path_rate=0.1, + ) + self.decode_head = SegFormerHead(num_classes, self.in_channels, embedding_dim) + + def construct(self, inputs): + h, w = inputs.shape[2], inputs.shape[3] + + x = self.backbone(inputs) + x, v = self.decode_head(x) + + x = ops.interpolate(x, size=(h, w), mode="bilinear", align_corners=True) + v = ops.interpolate(v, size=(h, w), mode="bilinear", align_corners=True) + return x, v diff --git a/MindEarth/applications/sea/LeadFormer/src/solver.py b/MindEarth/applications/sea/LeadFormer/src/solver.py new file mode 100644 index 0000000000000000000000000000000000000000..1af8bb3bc77dcbe249be35572bb6eae3450d9cf5 --- /dev/null +++ b/MindEarth/applications/sea/LeadFormer/src/solver.py @@ -0,0 +1,193 @@ +# Copyright 2023 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""train""" +import os + +import numpy as np + +import mindspore +import mindspore.nn as nn +from mindspore import Model, context +from mindspore.communication.management import init, get_rank, get_group_size +from mindspore.train.callback import CheckpointConfig, ModelCheckpoint +from mindspore.context import ParallelMode +from mindspore.train.serialization import load_checkpoint, load_param_into_net + +from .segformer import SegFormer +from .data_loader import create_multi_class_dataset +from .loss import EvolutionLoss +from .utils import StepLossTimeMonitor, filter_checkpoint_parameter_by_list +from .myop import AdamWeightDecay + + +def learning_rate_function(lr_all, cur_step_num): + return lr_all[cur_step_num] + + +class Trainer(nn.Cell): + """Neural network trainer class that encapsulates the complete training pipeline + + This class is responsible for: + - Initializing training configurations and model + - Handling distributed training setup + - Managing dataset loading and preprocessing + - Configuring optimizer and learning rate schedule + - Executing training loop and saving checkpoints + + Attributes: + epochs (int): Total number of training epochs + rank (int): Process rank in distributed training + group_size (int): Total number of processes in distributed training + data_dir (str): Path to training data directory + run_distribute (bool): Whether to enable distributed training + model_name (str): Name of the model being used + net (nn.Cell): Neural network model instance + resume (bool): Whether to resume training from checkpoint + resume_ckpt (str): Path to checkpoint for resuming training + transfer_training (bool): Whether to perform transfer learning + filter_weight (list): List of parameter names to filter during transfer learning + repeat (int): Number of dataset repetitions + split (float): Train/validation split ratio + num_classes (int): Number of classes in classification task + train_augment (bool): Whether to enable training data augmentation + output_path (str): Path to save training outputs + keep_checkpoint_max (int): Maximum number of checkpoints to keep + weight_decay (float): Weight decay coefficient for optimizer + batch_size (int): Training batch size + lr (float): Initial learning rate + amp_level (str): Auto mixed precision level (O0/O1/O2/O3) + """ + def __init__(self, config, epochs=400): + super().__init__() + self.epochs = epochs + self.rank = 0 + self.group_size = 1 + self.data_dir = config["data"].get("data_path", "") + self.run_distribute = config["train"].get("run_distribute", False) + self.model_name = config["model"].get("name", "") + self.in_channels = config["model"].get("in_channels", [64, 128, 320, 512]) + if self.model_name == "ice_simple": + self.net = SegFormer( + in_channels=self.in_channels, + num_classes=1, + embedding_dim=256, + ) + else: + raise ValueError("Unsupported model: {}".format(self.model_name)) + self.resume = config["train"].get("resume", False) + self.resume_ckpt = config["train"].get("resume_ckpt", "./") + self.transfer_training = config["train"].get("transfer_training", False) + self.filter_weight = config["model"].get("filter_weight", []) + self.repeat = config["train"].get("repeat", 1) + self.split = config["data"].get("split", 0.98) + self.num_classes = config["model"].get("num_classes", 1) + self.train_augment = config["data"].get("train_augment", False) + self.output_path = config["summary"].get("output_path", "./train") + 
diff --git a/MindEarth/applications/sea/LeadFormer/src/utils.py b/MindEarth/applications/sea/LeadFormer/src/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..16f2a87da0c10083b2564260b9983f192abab98e
--- /dev/null
+++ b/MindEarth/applications/sea/LeadFormer/src/utils.py
@@ -0,0 +1,219 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""util"""
+import os
+import time
+import struct
+import sys
+
+import numpy as np
+
+import mindspore
+import mindspore.ops as ops
+from mindspore.train.callback import Callback
+from mindspore.common.tensor import Tensor
+
+
+class StepLossTimeMonitor(Callback):
+    """Callback for monitoring training progress, loss, and performance metrics
+
+    This callback tracks and logs:
+    - Training loss at each step
+    - Current learning rate
+    - Processing speed (FPS)
+    - Epoch-level metrics
+
+    Attributes:
+        _per_print_times (int): Frequency of logging (every n steps)
+        batch_size (int): Training batch size
+        rank (int): Process rank in distributed training
+        step_time (float): Timestamp for step timing
+        epoch_start (float): Timestamp for epoch timing
+        losses (list): List to store loss values per epoch
+        learning_rate_func (function): Function to compute learning rate
+        lr_all (np.array): Array containing all learning rate values
+    """
+    def __init__(
+        self, lr_all, learning_rate_func, batch_size, per_print_times=1, rank=0
+    ):
+        super().__init__()
+        if not isinstance(per_print_times, int) or per_print_times < 0:
+            raise ValueError("print_step must be int and >= 0.")
+        self._per_print_times = per_print_times
+        self.batch_size = batch_size
+        self.rank = rank
+        self.step_time = 0
+        self.epoch_start = 0
+        self.losses = []
+        self.learning_rate_func = learning_rate_func
+        self.lr_all = lr_all
+
+    def step_begin(self, run_context):
+        """step_begin"""
+        self.step_time = time.time()
+        self.run_context = run_context
+
+    def step_end(self, run_context):
+        """step_end"""
+        step_seconds = time.time() - self.step_time
+        step_fps = self.batch_size * 1.0 / step_seconds
+
+        cb_params = run_context.original_args()
+        loss = cb_params.net_outputs
+        arr_lr = cb_params.optimizer.learning_rate.asnumpy()
+        new_lr = self.learning_rate_func(self.lr_all, cb_params.cur_step_num)
+        ops.assign(cb_params.optimizer.learning_rate, Tensor(new_lr, mindspore.float32))
+
+        if isinstance(loss, (tuple, list)):
+            if isinstance(loss[0], Tensor) and isinstance(
+                loss[0].asnumpy(), np.ndarray
+            ):
+                loss = loss[0]
+
+        if isinstance(loss, Tensor) and isinstance(loss.asnumpy(), np.ndarray):
+            loss = np.mean(loss.asnumpy())
+
+        cur_step_in_epoch = (cb_params.cur_step_num - 1) % cb_params.batch_num + 1
+
+        if isinstance(loss, float) and (np.isnan(loss) or np.isinf(loss)):
+            raise ValueError(
+                "epoch: {} step: {}. Invalid loss, terminating training.".format(
+                    cb_params.cur_epoch_num, cur_step_in_epoch
+                )
+            )
+        self.losses.append(loss)
+        if self._per_print_times != 0:
+            print(
+                "step: %s, loss is %s, fps is %s, lr is %s"
+                % (cur_step_in_epoch, loss, step_fps, arr_lr),
+                flush=True,
+            )
+
+    def epoch_begin(self, run_context):
+        self.epoch_start = time.time()
+        self.losses = []
+        self.run_context = run_context
+
+    def epoch_end(self, run_context):
+        cb_params = run_context.original_args()
+        epoch_cost = time.time() - self.epoch_start
+        step_in_epoch = (cb_params.cur_step_num - 1) % cb_params.batch_num + 1
+        step_fps = self.batch_size * 1.0 * step_in_epoch / epoch_cost
+        if self.rank == 0:
+            print(
+                "epoch: {:3d}, avg loss:{:.4f}, total cost: {:.3f} s, per step fps:{:5.3f}".format(
+                    cb_params.cur_epoch_num, np.mean(self.losses), epoch_cost, step_fps
+                ),
+                flush=True,
+            )
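# Illustrative sketch (not part of the patch): the in-place learning-rate update used
# by StepLossTimeMonitor above, reduced to a standalone snippet. The toy network and
# the three-value schedule are assumptions for demonstration only.
import mindspore
import mindspore.nn as nn
import mindspore.ops as ops
from mindspore import Tensor

toy_net = nn.Dense(4, 1)
opt = nn.Adam(toy_net.trainable_params(), learning_rate=1e-4)
schedule = [1e-4, 9e-5, 8e-5]
for step, new_lr in enumerate(schedule, start=1):
    # same mechanism as the callback: overwrite the optimizer's lr Parameter each step
    ops.assign(opt.learning_rate, Tensor(new_lr, mindspore.float32))
    print(step, float(opt.learning_rate.asnumpy()))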
+
+
+def filter_checkpoint_parameter_by_list(param_dict, filter_list):
+    """remove useless parameters according to filter_list"""
+    for key in list(param_dict.keys()):
+        for name in filter_list:
+            if name in key:
+                print("Delete parameter from checkpoint: ", key)
+                del param_dict[key]
+                break
+
+
+def make_grid(inputs):
+    """get 2D grid"""
+    batch_size, _, height, width = inputs.shape
+    xx = np.arange(0, width).reshape(1, -1)
+    xx = np.tile(xx, (height, 1))
+    yy = np.arange(0, height).reshape(-1, 1)
+    yy = np.tile(yy, (1, width))
+    xx = xx.reshape(1, 1, height, width)
+    xx = np.tile(xx, (batch_size, 1, 1, 1))
+    yy = yy.reshape(1, 1, height, width)
+    yy = np.tile(yy, (batch_size, 1, 1, 1))
+    grid = np.concatenate((xx, yy), axis=1).astype(np.float32)
+    return grid
+
+
+def warp(inputs, flow, grid, mode="bilinear", padding_mode="zeros"):
+    width = inputs.shape[-1]
+    vgrid = grid + flow
+    vgrid = 2.0 * vgrid / max(width - 1, 1) - 1.0
+    vgrid = vgrid.transpose(0, 2, 3, 1)
+    output = ops.grid_sample(
+        inputs, vgrid, padding_mode=padding_mode, mode=mode, align_corners=True
+    )
+    return output
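# Illustrative sketch (not part of the patch): what make_grid produces for a tiny
# 2 x 3 field. Channel 0 carries the column (x) index and channel 1 the row (y) index.
# warp() then adds a pixel-space flow to this grid and rescales it to the [-1, 1] range
# expected by grid_sample; note both axes are normalised by the width, so the fields
# are implicitly assumed to be square.
import numpy as np
from src.utils import make_grid  # import path assumed

dummy = np.zeros((1, 1, 2, 3), dtype=np.float32)
grid = make_grid(dummy)
print(grid.shape)   # (1, 2, 2, 3)
print(grid[0, 0])   # [[0. 1. 2.]
                    #  [0. 1. 2.]]  -> x (column) index
print(grid[0, 1])   # [[0. 0. 0.]
                    #  [1. 1. 1.]]  -> y (row) index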
+
+
+def readbin(filename, size, precision="real*4", skip=0, endianness="ieee-be"):
+    """read a ndarray from a binary file written for MITgcm."""
+    if endianness == "ieee-be":
+        df_part1 = ">"
+    elif endianness == "ieee-le":
+        df_part1 = "<"
+    else:
+        print("Error endianness!")
+        sys.exit(1)
+
+    if precision == "real*4":
+        df_part2 = "f"
+        length = 4
+    elif precision == "real*8":
+        df_part2 = "d"
+        length = 8
+    else:
+        print("Error precision!")
+        sys.exit(1)
+
+    dataformat = df_part1 + str(np.prod(size)) + df_part2
+    fout = open(filename, "rb")
+    if skip != 0:
+        fout.seek(np.prod(size) * length * skip)
+    data = struct.unpack(dataformat, fout.read(length * np.prod(size)))
+    fout.close()
+    return np.reshape(data, size, order="F")
+
+
+def writebin(filename, ndarray, precision="real*4", skip=0, endianness="ieee-be"):
+    """write a ndarray into binary file for MITgcm."""
+    size = np.prod(ndarray.shape)
+    arraycol = np.reshape(ndarray, (size, 1), order="F")
+    if endianness == "ieee-be":
+        df_part1 = ">"
+    elif endianness == "ieee-le":
+        df_part1 = "<"
+    else:
+        print("Error endianness!")
+        sys.exit(1)
+
+    if precision == "real*4":
+        df_part2 = "f"
+        length = 4
+    elif precision == "real*8":
+        df_part2 = "d"
+        length = 8
+    else:
+        print("Error precision!")
+        sys.exit(1)
+
+    dataformat = df_part1 + str(size) + df_part2
+
+    if os.path.isfile(filename):
+        fout = open(filename, "r+b")
+    else:
+        fout = open(filename, "wb")
+
+    fout.seek(size * skip * length, 0)
+    fout.write(struct.pack(dataformat, *arraycol))
+    fout.close()
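# Illustrative sketch (not part of the patch): round-tripping a small field through the
# MITgcm-style binary helpers above (big-endian real*4 by default). The file name is
# arbitrary and the helpers are assumed to behave exactly as written.
import numpy as np
from src.utils import readbin, writebin  # import path assumed

field = np.arange(6, dtype=np.float32).reshape(2, 3)
writebin("demo_field.bin", field)
restored = readbin("demo_field.bin", (2, 3))
print(np.allclose(field, restored))  # True: Fortran-order write matches Fortran-order read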