diff --git a/MindEarth/applications/nowcasting/PreDiff/README.md b/MindEarth/applications/nowcasting/PreDiff/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..7d0c1c7b20bd554b0dc40ce3c3a2bfff60ffc5f8
--- /dev/null
+++ b/MindEarth/applications/nowcasting/PreDiff/README.md
@@ -0,0 +1,69 @@
+# PreDiff: Short-Term Precipitation Forecasting Based on Latent Diffusion Models
+
+## Overview
+
+Traditional weather forecasting techniques rely on complex physical models. These models are computationally expensive and demand deep domain expertise. Over the past decade, however, the explosive growth of Earth spatio-temporal observation data has opened new avenues for building data-driven prediction models with deep learning. Although such models have shown great potential in a variety of Earth system prediction tasks, they still fall short in handling uncertainty and integrating domain-specific prior knowledge, often producing blurry or physically implausible predictions.
+
+To overcome these challenges, Gao Zhihan et al. from the Hong Kong University of Science and Technology proposed **PreDiff**, a model designed for probabilistic spatio-temporal prediction. The method couples a conditional latent diffusion model with an explicit knowledge alignment mechanism, so that the generated predictions respect domain-specific physical constraints while accurately capturing spatio-temporal evolution, which is expected to significantly improve the accuracy and reliability of Earth system predictions. Refined outputs are then generated on this basis to produce the final precipitation forecast. The model framework is shown below (the image is taken from the paper [PreDiff: Precipitation Nowcasting with Latent Diffusion Models](https://openreview.net/pdf?id=Gh67ZZ6zkS)).
+
+![prediff](images/train.jpg)
+
+During training, a variational autoencoder first compresses the input data into a latent space. A diffusion time step is then sampled at random, and the corresponding noise is added to the latent representation. The noisy latents are fed into Earthformer-UNet for denoising; this network adopts a UNet architecture with cuboid attention and removes the cross-attention that connects the encoder and decoder in Earthformer. Finally, the denoised latents are decoded by the variational autoencoder to recover the data. In this way, the diffusion model learns the data distribution by reversing the predefined noise-injection process that corrupts the original data.
+
+This tutorial demonstrates how to train the model and run fast inference with MindEarth. For more information, refer to the [article](https://openreview.net/pdf?id=Gh67ZZ6zkS).
+In this tutorial, the open-source [SEVIR-LR](https://deep-earth.s3.amazonaws.com/datasets/sevir_lr.zip) dataset is used for training and inference.
+
+The checkpoint (ckpt) files required for training and inference can be downloaded from [ckpt](https://download-mindspore.osinfra.cn/mindscience/mindearth/dataset/PreDiff/).
+
+## Quick Start
+
+Download and extract the data from [PreDiff/dataset](https://deep-earth.s3.amazonaws.com/datasets/sevir_lr.zip), and then modify the `root_dir` path in `./configs/diffusion.yaml`.
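+The dataset path is read from the `data` section of `./configs/diffusion.yaml`; a minimal excerpt of that section is shown below (the path is the default shipped in this repository — point it to wherever you extracted SEVIR-LR):
+
+```yaml
+data:
+  dataset_name: "sevirlr"
+  root_dir: "./dataset/sevir_lr"
+```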
+
+### Running Method: Invoke the `main` script from the command line
+
+```bash
+
+python main.py --device_id 0 --device_target Ascend --cfg ./configs/diffusion.yaml --mode train
+
+```
+
+Here, `--device_target` specifies the device type (default: Ascend), `--device_id` specifies the ID of the device to run on (default: 0), `--cfg` is the path to the configuration file (default: `./configs/diffusion.yaml`), and `--mode` is the running mode (default: `train`).
+
+### Inference
+
+Set `ckpt_path` in `./configs/diffusion.yaml` to the checkpoint path of the diffusion model.
+
+```bash
+
+python main.py --device_id 0 --mode test
+
+```
+
+### Results Presentation
+
+#### Visualization of Prediction Results
+
+The following figure shows the inference results after training on approximately 20,000 samples for 2000 epochs.
+
+![diffusion](images/diffusion_result.png)
+
+### Performance
+
+| Parameter | NPU |
+|:----------------------:|:--------------------------:|
+| Hardware | Ascend, 64G |
+| MindSpore Version | 2.5.0 |
+| Dataset | SEVIR-LR |
+| Training Parameters | batch_size=1, steps_per_epoch=24834, epochs=5 |
+| Testing Parameters | batch_size=1, steps=2500 |
+| Optimizer | AdamW |
+| POD (Probability of Detection, 16h) | 0.50 |
+| Training Loss (MSE) | 0.0857 |
+| Training Resources | 1 Node, 1 NPU |
+| Training Speed (ms/step) | 3000 |
+
+## Contributors
+
+Gitee ID: funfunplus
+
+Email: funniless@163.com
\ No newline at end of file
diff --git a/MindEarth/applications/nowcasting/PreDiff/README_CN.md b/MindEarth/applications/nowcasting/PreDiff/README_CN.md
new file mode 100644
index 0000000000000000000000000000000000000000..49904f6afe0edd00fd4b45346aa43ebbe70c4f16
--- /dev/null
+++ b/MindEarth/applications/nowcasting/PreDiff/README_CN.md
@@ -0,0 +1,70 @@
+# PreDiff: 基于潜在扩散模型的降水短时预报
+
+## 概述
+
+传统的天气预报技术依赖于复杂的物理模型,这些模型不仅计算成本高昂,还要求深厚的专业知识支撑。然而,近十年来,随着地球时空观测数据的爆炸式增长,深度学习技术为构建数据驱动的预测模型开辟了新的道路。虽然这些模型在多种地球系统预测任务中展现出巨大潜力,但它们在管理不确定性和整合特定领域先验知识方面仍有不足,时常导致预测结果模糊不清或在物理上不可信。
+
+为克服这些难题,来自香港科技大学的Gao Zhihan等人提出了**PreDiff**模型,专门用于实现概率性的时空预测。该方法融合了条件潜在扩散模型与显式的知识对齐机制,旨在生成既符合特定领域物理约束,又能精确捕捉时空变化的预测结果,从而显著提升地球系统预测的准确性和可靠性,并在此基础上生成精细化的结果,得到最终的降水预报。模型框架图如下图所示(图片来源于论文 [PreDiff: Precipitation Nowcasting with Latent Diffusion Models](https://openreview.net/pdf?id=Gh67ZZ6zkS))。
+
+![prediff](images/train.jpg)
+
+训练过程中,数据先通过变分自编码器将关键信息压缩到隐空间,之后随机选择时间步并生成对应的噪声,对隐变量进行加噪处理。随后将加噪数据输入Earthformer-UNet进行去噪,Earthformer-UNet采用UNet架构和cuboid attention,并去除了Earthformer中连接encoder和decoder的cross-attention结构。最后将结果通过变分自编码器解码,还原得到去噪后的数据。扩散模型通过反转预先定义的、破坏原始数据的加噪过程来学习数据分布。
+
+本教程展示了如何通过MindEarth训练和快速推理模型,更多信息参见[文章](https://openreview.net/pdf?id=Gh67ZZ6zkS)。
+本教程使用开源数据集[SEVIR-LR](https://deep-earth.s3.amazonaws.com/datasets/sevir_lr.zip)进行训练和推理。
+
+可以在[ckpt](https://download-mindspore.osinfra.cn/mindscience/mindearth/dataset/PreDiff/)下载训练和推理所需的ckpt文件。
+
+## 快速开始
+
+在[PreDiff/dataset](https://deep-earth.s3.amazonaws.com/datasets/sevir_lr.zip)下载数据并保存,然后在`./configs/diffusion.yaml`中修改`root_dir`路径。
+
+### 运行方式: 在命令行调用`main`脚本
+
+```bash
+
+python main.py --device_id 0 --device_target Ascend --cfg ./configs/diffusion.yaml --mode train
+
+```
+
+其中,`--device_target`表示设备类型,默认Ascend;`--device_id`表示运行设备的编号,默认值0;`--cfg`为配置文件路径,默认值`./configs/diffusion.yaml`;`--mode`为运行模式,默认值`train`。
+
+### 推理
+
+在`./configs/diffusion.yaml`中设置`ckpt_path`为diffusion模型的ckpt路径。
+
+```bash
+
+python main.py --device_id 0 --mode test
+
+```
+
+### 结果展示
+
+#### 预测结果可视化
+ +下图展示了使用约2w条样本训练2000个epoch后进行推理绘制的结果。 + +![diffusion](images/diffusion_result.png) + +### 性能 + +| Parameter | NPU | +|:----------------------:|:--------------------------:| +| 硬件版本 | Ascend, 64G | +| mindspore版本 | 2.5.0 | +| 数据集 | SEVIR-LR | +| 训练参数 | batch_size=1, steps_per_epoch=24834, epochs=5 | +| 测试参数 | batch_size=1,steps=2500 | +| 优化器 | AdamW | +| 训练损失(MSE) | 0.0857 | +| POD预测命中率(16h) | 0.50 | +| 训练资源 | 1Node 1NPU | +| 训练速度(ms/step) | 3000ms | + +## 贡献者 + +gitee id: funfunplus + +email: funniless@163.com \ No newline at end of file diff --git a/MindEarth/applications/nowcasting/PreDiff/configs/diffusion.yaml b/MindEarth/applications/nowcasting/PreDiff/configs/diffusion.yaml new file mode 100644 index 0000000000000000000000000000000000000000..408dfe6f013d21af1b9292c6fc461622bfef82b6 --- /dev/null +++ b/MindEarth/applications/nowcasting/PreDiff/configs/diffusion.yaml @@ -0,0 +1,193 @@ +data: + dataset_name: "sevirlr" + seq_in: 13 + plot_stride: 1 + interval_real_time: 10 + raw_seq_len: 25 + sample_mode: "sequent" + stride: 6 + layout: "NTHWC" + start_date: null + train_val_split_date: [2019, 3, 19] + train_test_split_date: [2019, 6, 1] + end_date: null + val_ratio: 0.1 + metrics_mode: "0" + metrics_list: ['csi', 'pod', 'sucr', 'bias'] + threshold_list: [16, 74, 133, 160, 181, 219] + aug_mode: "0" + root_dir: "./dataset/sevir_lr" +layout: + t_in: 7 + t_out: 6 + data_channels: 1 + layout: "NTHWC" +optim: + total_batch_size: 64 + micro_batch_size: 2 + seed: 0 + float32_matmul_precision: "high" + method: "adamw" + lr: 1.0e-5 + betas: [0.9, 0.999] + gradient_clip_val: 1.0 + max_epochs: 2000 + loss_type: "l2" + warmup_percentage: 0.1 + lr_scheduler_mode: "cosine" + min_lr_ratio: 1.0e-3 + warmup_min_lr_ratio: 0.1 + monitor: "val/loss" + early_stop: false + early_stop_mode: "min" + early_stop_patience: 100 + save_top_k: 3 +logging: + logging_prefix: "PreDiff" + monitor_lr: true + monitor_device: false + track_grad_norm: -1 + use_wandb: false + profiler: null +trainer: + check_val_every_n_epoch: 50 + log_step_ratio: 0.001 + precision: 32 + find_unused_parameters: false +eval: + train_example_data_idx_list: [0, ] + val_example_data_idx_list: [0, 16, 32, 48, 64, 72, 96, 108, 128] + test_example_data_idx_list: [0, 16, 32, 48, 64, 72, 96, 108, 128] + eval_example_only: true + eval_aligned: true + eval_unaligned: true + num_samples_per_context: 1 + fs: 20 + label_offset: [-0.5, 0.5] + label_avg_int: false + fvd_features: 400 +model: + diffusion: + data_shape: [6, 128, 128, 1] + timesteps: 1000 + beta_schedule: "linear" + log_every_t: 100 + clip_denoised: false + linear_start: 1e-4 + linear_end: 2e-2 + cosine_s: 8e-3 + given_betas: null + original_elbo_weight: 0. + v_posterior: 0. + l_simple_weight: 1. + learn_logvar: false + logvar_init: 0. 
+ latent_shape: [6, 16, 16, 64] + cond_stage_forward: null + scale_by_std: false + scale_factor: 1.0 + latent_cond_shape: [7, 16, 16, 64] + align: + alignment_type: "avg_x" + guide_scale: 50.0 + model_type: "cuboid" + model_args: + input_shape: [6, 16, 16, 64] + out_channels: 1 + base_units: 128 + scale_alpha: 1.0 + depth: [1, 1] + downsample: 2 + downsample_type: "patch_merge" + use_attn_pattern: true + num_heads: 4 + attn_drop: 0.1 + proj_drop: 0.1 + ffn_drop: 0.1 + ffn_activation: "gelu" + gated_ffn: false + norm_layer: "layer_norm" + use_inter_ffn: true + hierarchical_pos_embed: false + padding_type: "zeros" + use_relative_pos: true + self_attn_use_final_proj: true + num_global_vectors: 0 + use_global_vector_ffn: true + use_global_self_attn: false + separate_global_qkv: false + global_dim_ratio: 1 + attn_linear_init_mode: "0" + ffn_linear_init_mode: "0" + ffn2_linear_init_mode: "2" + attn_proj_linear_init_mode: "2" + conv_init_mode: "0" + down_linear_init_mode: "0" + global_proj_linear_init_mode: "2" + norm_init_mode: "0" + time_embed_channels_mult: 4 + time_embed_use_scale_shift_norm: false + time_embed_dropout: 0.0 + pool: "attention" + readout_seq: true + t_out: 6 + model_ckpt_path: "./ckpt/align.ckpt" + latent_model: + input_shape: [7, 16, 16, 64] + target_shape: [6, 16, 16, 64] + base_units: 256 + block_units: Null + scale_alpha: 1.0 + num_heads: 4 + attn_drop: 0.1 + proj_drop: 0.1 + ffn_drop: 0.1 + downsample: 2 + downsample_type: "patch_merge" + upsample_type: "upsample" + upsample_kernel_size: 3 + depth: [4, 4] + use_attn_pattern: true + num_global_vectors: 0 + use_global_vector_ffn: false + use_global_self_attn: true + separate_global_qkv: true + global_dim_ratio: 1 + ffn_activation: "gelu" + gated_ffn: false + norm_layer: "layer_norm" + padding_type: "zeros" + use_relative_pos: true + self_attn_use_final_proj: true + attn_linear_init_mode: "0" + ffn_linear_init_mode: "0" + ffn2_linear_init_mode: "2" + attn_proj_linear_init_mode: "2" + conv_init_mode: "0" + down_linear_init_mode: "0" + global_proj_linear_init_mode: "2" + norm_init_mode: "0" + time_embed_channels_mult: 4 + time_embed_use_scale_shift_norm: false + time_embed_dropout: 0.0 + unet_res_connect: true + vae: + pretrained_ckpt_path: "./ckpt/vae.ckpt" + data_channels: 1 + down_block_types: ['DownEncoderBlock2D', 'DownEncoderBlock2D', 'DownEncoderBlock2D', 'DownEncoderBlock2D'] + in_channels: 1 + block_out_channels: [128, 256, 512, 512] + act_fn: 'silu' + latent_channels: 64 + up_block_types: ['UpDecoderBlock2D', 'UpDecoderBlock2D', 'UpDecoderBlock2D', 'UpDecoderBlock2D'] + norm_num_groups: 32 + layers_per_block: 2 + out_channels: 1 +summary: + summary_dir: "./summary/prediff" + eval_interval: 10 + save_ckpt_epochs: 1 + keep_ckpt_max: 100 + ckpt_path: "./ckpt/diffusion.ckpt" + load_ckpt: false + diff --git a/MindEarth/applications/nowcasting/PreDiff/images/diffusion_result.png b/MindEarth/applications/nowcasting/PreDiff/images/diffusion_result.png new file mode 100644 index 0000000000000000000000000000000000000000..7bfe96978e3496305c62a3301d34826dd932521f Binary files /dev/null and b/MindEarth/applications/nowcasting/PreDiff/images/diffusion_result.png differ diff --git a/MindEarth/applications/nowcasting/PreDiff/images/train.jpg b/MindEarth/applications/nowcasting/PreDiff/images/train.jpg new file mode 100644 index 0000000000000000000000000000000000000000..147412d83723f0f0203681a30152774b5895bc68 Binary files /dev/null and b/MindEarth/applications/nowcasting/PreDiff/images/train.jpg differ diff --git 
a/MindEarth/applications/nowcasting/PreDiff/main.py b/MindEarth/applications/nowcasting/PreDiff/main.py new file mode 100644 index 0000000000000000000000000000000000000000..e430ab0ae3ee3303e2fce4604df6c322cb7903cb --- /dev/null +++ b/MindEarth/applications/nowcasting/PreDiff/main.py @@ -0,0 +1,77 @@ +# Copyright 2023 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"diffusion main" +import argparse + +import mindspore as ms +from mindspore import set_seed, context +from mindearth.utils import load_yaml_config + +from src import ( + prepare_output_directory, + configure_logging_system, + prepare_dataset, + init_model, + PreDiffModule, + DiffusionTrainer, + DiffusionInferrence +) + + +set_seed(0) + + +def get_parser(): + parser = argparse.ArgumentParser() + parser.add_argument("--device_id", default=0, type=int) + parser.add_argument("--device_target", default="Ascend", type=str) + parser.add_argument('--cfg', default="./configs/diffusion.yaml", type=str) + parser.add_argument("--mode", default="train") + params = parser.parse_args() + return params + + +def train(cfg, arg, module): + output_dir = prepare_output_directory(cfg, arg.device_id) + logger = configure_logging_system(output_dir, cfg) + dm, total_num_steps = prepare_dataset(cfg, PreDiffModule) + trainer = DiffusionTrainer( + main_module=module, dm=dm, logger=logger, config=cfg + ) + trainer.train(total_steps=total_num_steps) + + +def test(cfg, arg, module): + output_dir = prepare_output_directory(cfg, arg.device_id) + logger = configure_logging_system(output_dir, cfg) + dm, _ = prepare_dataset(cfg, PreDiffModule) + tester = DiffusionInferrence( + main_module=module, dm=dm, logger=logger, config=cfg + ) + tester.test() + + +if __name__ == "__main__": + args = get_parser() + config = load_yaml_config(args.cfg) + context.set_context(mode=ms.PYNATIVE_MODE) + ms.set_device(device_target=args.device_target, device_id=args.device_id) + main_module = PreDiffModule(oc_file=args.cfg) + main_module = init_model(module=main_module, config=config, mode=args.mode) + if args.mode == "train": + train(config, args, main_module) + else: + test(config, args, main_module) + \ No newline at end of file diff --git a/MindEarth/applications/nowcasting/PreDiff/prediff.ipynb b/MindEarth/applications/nowcasting/PreDiff/prediff.ipynb new file mode 100644 index 0000000000000000000000000000000000000000..777d8c62a0b11c7ba569b0cdbffc6aaefcb548b1 --- /dev/null +++ b/MindEarth/applications/nowcasting/PreDiff/prediff.ipynb @@ -0,0 +1,1783 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "bb7af2dd-61ff-4033-bc04-3a1ec598f5c4", + "metadata": {}, + "source": [ + "# PreDiff: 基于潜在扩散模型的降水短时预报\n", + "\n", + "## 概述\n", + "\n", + "传统的天气预报技术依赖于复杂的物理模型,这些模型不仅计算成本高昂,还要求深厚的专业知识支撑。然而,近十年来,随着地球时空观测数据的爆炸式增长,深度学习技术为构建数据驱动的预测模型开辟了新的道路。虽然这些模型在多种地球系统预测任务中展现出巨大潜力,但它们在管理不确定性和整合特定领域先验知识方面仍有不足,时常导致预测结果模糊不清或在物理上不可信。\n", + "\n", + "为克服这些难题,来自香港科技大学的Gao 
Zhihan实现了**prediff**模型,专门用于实现概率性的时空预测。该流程融合了条件潜在扩散模型与显式的知识对齐机制,旨在生成既符合特定领域物理约束,又能精确捕捉时空变化的预测结果。通过这种方法,我们期望能够显著提升地球系统预测的准确性和可靠性。\n", + "基础上生成精细化的结果,从而得到最终的降水预报。模型框架图入下图所示(图片来源于论文 [PreDiff: Precipitation Nowcasting with Latent Diffusion Models](https://openreview.net/pdf?id=Gh67ZZ6zkS))\n", + "\n", + "![prediff](images/train.jpg)\n", + "\n", + "训练的过程中,数据通过变分自编码器提取关键信息到隐空间,之后随机选择时间步生成对应时间步噪声,对数据进行加噪处理。之后将数据输入Earthformer-UNet进行降噪处理,Earthformer-UNet采用了UNet构架和cuboid attention,去除了Earthformer中连接encoder和decoder的cross-attention结构。最后将结果通过变分自解码器还原得到去噪后的数据,扩散模型通过反转预先定义的破坏原始数据的加噪过程来学习数据分布。" + ] + }, + { + "cell_type": "markdown", + "id": "d3ca5c0b-4fe8-4764-b046-b154df49cc9b", + "metadata": {}, + "source": [ + "## 概述\n", + "\n", + "MindSpore Earth求解该问题的具体流程如下:\n", + "\n", + "1.创建数据集\n", + "\n", + "2.模型构建\n", + "\n", + "3.损失函数\n", + "\n", + "4.模型训练\n", + "\n", + "5.模型评估与可视化\n", + "\n", + "数据集可以在[PreDiff/dataset](https://deep-earth.s3.amazonaws.com/datasets/sevir_lr.zip)下载数据并保存" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "56631948-30aa-4360-84d1-e42bff3ab87c", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/data/zmmVol1/miniconda3/envs/lryms25py311/lib/python3.11/site-packages/numpy/core/getlimits.py:549: UserWarning: The value of the smallest subnormal for type is zero.\n", + " setattr(self, word, getattr(machar, word).flat[0])\n", + "/data/zmmVol1/miniconda3/envs/lryms25py311/lib/python3.11/site-packages/numpy/core/getlimits.py:89: UserWarning: The value of the smallest subnormal for type is zero.\n", + " return self._float_to_str(self.smallest_subnormal)\n", + "/data/zmmVol1/miniconda3/envs/lryms25py311/lib/python3.11/site-packages/numpy/core/getlimits.py:549: UserWarning: The value of the smallest subnormal for type is zero.\n", + " setattr(self, word, getattr(machar, word).flat[0])\n", + "/data/zmmVol1/miniconda3/envs/lryms25py311/lib/python3.11/site-packages/numpy/core/getlimits.py:89: UserWarning: The value of the smallest subnormal for type is zero.\n", + " return self._float_to_str(self.smallest_subnormal)\n" + ] + } + ], + "source": [ + "import time\n", + "import os\n", + "import random\n", + "import json\n", + "from typing import Sequence, Union\n", + "import numpy as np\n", + "from einops import rearrange\n", + "\n", + "import mindspore as ms\n", + "from mindspore import set_seed, context, ops, nn, mint\n", + "from mindspore.experimental import optim\n", + "from mindspore.train.serialization import save_checkpoint" + ] + }, + { + "cell_type": "markdown", + "id": "415eb386-b9ef-42af-9ab6-f98c9d8151da", + "metadata": {}, + "source": [ + "下述src可以在[PreDiff/src](./src)下载" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "3ab31256-d802-4c68-9066-6c2cc9e73dcd", + "metadata": {}, + "outputs": [], + "source": [ + "from mindearth.utils import load_yaml_config\n", + "\n", + "from src import (\n", + " prepare_output_directory,\n", + " configure_logging_system,\n", + " prepare_dataset,\n", + " init_model,\n", + " PreDiffModule,\n", + " DiffusionTrainer,\n", + " DiffusionInferrence\n", + ")\n", + "from src.sevir_dataset import SEVIRDataset\n", + "from src.visual import vis_sevir_seq\n", + "from src.utils import warmup_lambda" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "e2576e2e-b98f-4791-8634-d68e55531ede", + "metadata": {}, + "outputs": [], + "source": [ + "set_seed(0)\n", + "np.random.seed(0)\n", + "random.seed(0)" + ] + }, + { + "cell_type": "markdown", + "id": 
"7272ed00-61ef-439c-a420-b3bcefc13965", + "metadata": {}, + "source": [ + "可以在[配置文件](./configs/diffusion.yaml)中配置模型、数据和优化器等参数。" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "6c2a2194-7a48-4412-b4d8-9bcae2d5c280", + "metadata": {}, + "outputs": [], + "source": [ + "config = load_yaml_config(\"./configs/diffusion.yaml\")\n", + "context.set_context(mode=ms.PYNATIVE_MODE)\n", + "ms.set_device(device_target=\"Ascend\", device_id=0)" + ] + }, + { + "cell_type": "markdown", + "id": "9e4b94cd-f806-4e90-bf67-030e66253274", + "metadata": {}, + "source": [ + "## 模型构建\n", + "\n", + "模型初始化主要包括vae模块load ckpt以及earthformer部分初始化" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "36dabe74-de56-40c2-9db6-370ad2d7a0fa", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "[WARNING] ME(2231351:281473150263328,MainProcess):2025-04-07-10:32:09.988.000 [mindspore/train/serialization.py:1956] For 'load_param_into_net', remove parameter prefix name: net., continue to load.\n", + "[WARNING] ME(2231351:281473150263328,MainProcess):2025-04-07-10:32:11.431.000 [mindspore/train/serialization.py:1956] For 'load_param_into_net', remove parameter prefix name: main_model., continue to load.\n", + "2025-04-07 10:32:11,466 - utils.py[line:820] - INFO: Process ID: 2231351\n", + "2025-04-07 10:32:11,467 - utils.py[line:821] - INFO: {'summary_dir': './summary/prediff/single_device0', 'eval_interval': 10, 'save_ckpt_epochs': 1, 'keep_ckpt_max': 100, 'ckpt_path': '/home/lry/202542测试/PreDiff/ckpt/diffusion.ckpt', 'load_ckpt': False}\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "NoisyCuboidTransformerEncoder param_not_load: []\n", + "Cleared previous output directory: ./summary/prediff/single_device0\n" + ] + } + ], + "source": [ + "main_module = PreDiffModule(oc_file=\"./configs/diffusion.yaml\")\n", + "main_module = init_model(module=main_module, config=config, mode=\"train\")\n", + "output_dir = prepare_output_directory(config, \"0\")\n", + "logger = configure_logging_system(output_dir, config)" + ] + }, + { + "cell_type": "markdown", + "id": "ef546194-7eef-46a6-8ffa-504d4b58fa25", + "metadata": {}, + "source": [ + "## 创建数据集\n", + "\n", + "下载[sevir-lr](https://deep-earth.s3.amazonaws.com/datasets/sevir_lr.zip)数据集到./dataset目录。" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "d4b45562-49b1-4ab3-ab52-bad420c30236", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "train\n", + " vil_filename \\\n", + "id \n", + "R18020113057733 0 vil/2018/SEVIR_VIL_RANDOMEVENTS_2018_0101_0430.h5 \n", + "R18020113057811 0 vil/2018/SEVIR_VIL_RANDOMEVENTS_2018_0101_0430.h5 \n", + "R18020113057875 0 vil/2018/SEVIR_VIL_RANDOMEVENTS_2018_0101_0430.h5 \n", + "R18020113057888 0 vil/2018/SEVIR_VIL_RANDOMEVENTS_2018_0101_0430.h5 \n", + "R18020113057982 0 vil/2018/SEVIR_VIL_RANDOMEVENTS_2018_0101_0430.h5 \n", + "R18020113058079 0 vil/2018/SEVIR_VIL_RANDOMEVENTS_2018_0101_0430.h5 \n", + "R18020113058477 0 vil/2018/SEVIR_VIL_RANDOMEVENTS_2018_0101_0430.h5 \n", + "R18020113058635 0 vil/2018/SEVIR_VIL_RANDOMEVENTS_2018_0101_0430.h5 \n", + "R18020306327357 0 vil/2018/SEVIR_VIL_RANDOMEVENTS_2018_0101_0430.h5 \n", + "R18020306327410 0 vil/2018/SEVIR_VIL_RANDOMEVENTS_2018_0101_0430.h5 \n", + "\n", + " vil_index \n", + "id \n", + "R18020113057733 0 1310 \n", + "R18020113057811 0 1311 \n", + "R18020113057875 0 1312 \n", + "R18020113057888 0 1309 \n", + "R18020113057982 0 1308 \n", + 
"R18020113058079 0 1306 \n", + "R18020113058477 0 1313 \n", + "R18020113058635 0 1307 \n", + "R18020306327357 0 98 \n", + "R18020306327410 0 103 \n", + "len 837\n", + "hdf_filenames ['vil/2018/SEVIR_VIL_RANDOMEVENTS_2018_0101_0430.h5', 'vil/2018/SEVIR_VIL_STORMEVENTS_2018_0101_0630.h5']\n", + "Opening HDF5 file for reading vil/2018/SEVIR_VIL_RANDOMEVENTS_2018_0101_0430.h5\n", + "f: vil/2018/SEVIR_VIL_RANDOMEVENTS_2018_0101_0430.h5\n", + "self._hdf_files[f]: \n", + "Opening HDF5 file for reading vil/2018/SEVIR_VIL_STORMEVENTS_2018_0101_0630.h5\n", + "f: vil/2018/SEVIR_VIL_STORMEVENTS_2018_0101_0630.h5\n", + "self._hdf_files[f]: \n", + "val\n", + " vil_filename \\\n", + "id \n", + "R18031116527414 0 vil/2018/SEVIR_VIL_RANDOMEVENTS_2018_0101_0430.h5 \n", + "R18031116527427 0 vil/2018/SEVIR_VIL_RANDOMEVENTS_2018_0101_0430.h5 \n", + "R18031116527438 0 vil/2018/SEVIR_VIL_RANDOMEVENTS_2018_0101_0430.h5 \n", + "R18031116527505 0 vil/2018/SEVIR_VIL_RANDOMEVENTS_2018_0101_0430.h5 \n", + "R18031116527686 0 vil/2018/SEVIR_VIL_RANDOMEVENTS_2018_0101_0430.h5 \n", + "R18031116527705 0 vil/2018/SEVIR_VIL_RANDOMEVENTS_2018_0101_0430.h5 \n", + "R18031116527805 0 vil/2018/SEVIR_VIL_RANDOMEVENTS_2018_0101_0430.h5 \n", + "R18031116527860 0 vil/2018/SEVIR_VIL_RANDOMEVENTS_2018_0101_0430.h5 \n", + "R18031116577456 0 vil/2018/SEVIR_VIL_RANDOMEVENTS_2018_0101_0430.h5 \n", + "R18031116577481 0 vil/2018/SEVIR_VIL_RANDOMEVENTS_2018_0101_0430.h5 \n", + "\n", + " vil_index \n", + "id \n", + "R18031116527414 0 1684 \n", + "R18031116527427 0 1685 \n", + "R18031116527438 0 1687 \n", + "R18031116527505 0 1686 \n", + "R18031116527686 0 1681 \n", + "R18031116527705 0 1682 \n", + "R18031116527805 0 1683 \n", + "R18031116527860 0 1680 \n", + "R18031116577456 0 883 \n", + "R18031116577481 0 882 \n", + "len 9939\n", + "hdf_filenames ['vil/2018/SEVIR_VIL_RANDOMEVENTS_2018_0101_0430.h5', 'vil/2018/SEVIR_VIL_RANDOMEVENTS_2018_0501_0831.h5', 'vil/2018/SEVIR_VIL_RANDOMEVENTS_2018_0901_1231.h5', 'vil/2018/SEVIR_VIL_STORMEVENTS_2018_0101_0630.h5', 'vil/2018/SEVIR_VIL_STORMEVENTS_2018_0701_1231.h5', 'vil/2019/SEVIR_VIL_RANDOMEVENTS_2019_0101_0430.h5', 'vil/2019/SEVIR_VIL_RANDOMEVENTS_2019_0501_0831.h5', 'vil/2019/SEVIR_VIL_STORMEVENTS_2019_0101_0630.h5']\n", + "Opening HDF5 file for reading vil/2018/SEVIR_VIL_RANDOMEVENTS_2018_0101_0430.h5\n", + "f: vil/2018/SEVIR_VIL_RANDOMEVENTS_2018_0101_0430.h5\n", + "self._hdf_files[f]: \n", + "Opening HDF5 file for reading vil/2018/SEVIR_VIL_RANDOMEVENTS_2018_0501_0831.h5\n", + "f: vil/2018/SEVIR_VIL_RANDOMEVENTS_2018_0501_0831.h5\n", + "self._hdf_files[f]: \n", + "Opening HDF5 file for reading vil/2018/SEVIR_VIL_RANDOMEVENTS_2018_0901_1231.h5\n", + "f: vil/2018/SEVIR_VIL_RANDOMEVENTS_2018_0901_1231.h5\n", + "self._hdf_files[f]: \n", + "Opening HDF5 file for reading vil/2018/SEVIR_VIL_STORMEVENTS_2018_0101_0630.h5\n", + "f: vil/2018/SEVIR_VIL_STORMEVENTS_2018_0101_0630.h5\n", + "self._hdf_files[f]: \n", + "Opening HDF5 file for reading vil/2018/SEVIR_VIL_STORMEVENTS_2018_0701_1231.h5\n", + "f: vil/2018/SEVIR_VIL_STORMEVENTS_2018_0701_1231.h5\n", + "self._hdf_files[f]: \n", + "Opening HDF5 file for reading vil/2019/SEVIR_VIL_RANDOMEVENTS_2019_0101_0430.h5\n", + "f: vil/2019/SEVIR_VIL_RANDOMEVENTS_2019_0101_0430.h5\n", + "self._hdf_files[f]: \n", + "Opening HDF5 file for reading vil/2019/SEVIR_VIL_RANDOMEVENTS_2019_0501_0831.h5\n", + "f: vil/2019/SEVIR_VIL_RANDOMEVENTS_2019_0501_0831.h5\n", + "self._hdf_files[f]: \n", + "Opening HDF5 file for reading 
vil/2019/SEVIR_VIL_STORMEVENTS_2019_0101_0630.h5\n", + "f: vil/2019/SEVIR_VIL_STORMEVENTS_2019_0101_0630.h5\n", + "self._hdf_files[f]: \n", + "test\n", + " vil_filename \\\n", + "id \n", + "R19060101537335 0 vil/2019/SEVIR_VIL_RANDOMEVENTS_2019_0501_0831.h5 \n", + "R19060101537490 0 vil/2019/SEVIR_VIL_RANDOMEVENTS_2019_0501_0831.h5 \n", + "R19060101537544 0 vil/2019/SEVIR_VIL_RANDOMEVENTS_2019_0501_0831.h5 \n", + "R19060101537604 0 vil/2019/SEVIR_VIL_RANDOMEVENTS_2019_0501_0831.h5 \n", + "R19060101537632 0 vil/2019/SEVIR_VIL_RANDOMEVENTS_2019_0501_0831.h5 \n", + "R19060101537740 0 vil/2019/SEVIR_VIL_RANDOMEVENTS_2019_0501_0831.h5 \n", + "R19060101538445 0 vil/2019/SEVIR_VIL_RANDOMEVENTS_2019_0501_0831.h5 \n", + "R19060101538500 0 vil/2019/SEVIR_VIL_RANDOMEVENTS_2019_0501_0831.h5 \n", + "R19060103457525 0 vil/2019/SEVIR_VIL_RANDOMEVENTS_2019_0501_0831.h5 \n", + "R19060103457650 0 vil/2019/SEVIR_VIL_RANDOMEVENTS_2019_0501_0831.h5 \n", + "\n", + " vil_index \n", + "id \n", + "R19060101537335 0 2317 \n", + "R19060101537490 0 2323 \n", + "R19060101537544 0 2318 \n", + "R19060101537604 0 2316 \n", + "R19060101537632 0 2322 \n", + "R19060101537740 0 2321 \n", + "R19060101538445 0 2320 \n", + "R19060101538500 0 2319 \n", + "R19060103457525 0 1469 \n", + "R19060103457650 0 1468 \n", + "len 4053\n", + "hdf_filenames ['vil/2019/SEVIR_VIL_RANDOMEVENTS_2019_0501_0831.h5', 'vil/2019/SEVIR_VIL_RANDOMEVENTS_2019_0901_1231.h5', 'vil/2019/SEVIR_VIL_STORMEVENTS_2019_0101_0630.h5', 'vil/2019/SEVIR_VIL_STORMEVENTS_2019_0701_1231.h5']\n", + "Opening HDF5 file for reading vil/2019/SEVIR_VIL_RANDOMEVENTS_2019_0501_0831.h5\n", + "f: vil/2019/SEVIR_VIL_RANDOMEVENTS_2019_0501_0831.h5\n", + "self._hdf_files[f]: \n", + "Opening HDF5 file for reading vil/2019/SEVIR_VIL_RANDOMEVENTS_2019_0901_1231.h5\n", + "f: vil/2019/SEVIR_VIL_RANDOMEVENTS_2019_0901_1231.h5\n", + "self._hdf_files[f]: \n", + "Opening HDF5 file for reading vil/2019/SEVIR_VIL_STORMEVENTS_2019_0101_0630.h5\n", + "f: vil/2019/SEVIR_VIL_STORMEVENTS_2019_0101_0630.h5\n", + "self._hdf_files[f]: \n", + "Opening HDF5 file for reading vil/2019/SEVIR_VIL_STORMEVENTS_2019_0701_1231.h5\n", + "f: vil/2019/SEVIR_VIL_STORMEVENTS_2019_0701_1231.h5\n", + "self._hdf_files[f]: \n" + ] + } + ], + "source": [ + "dm, total_num_steps = prepare_dataset(config, PreDiffModule)" + ] + }, + { + "cell_type": "markdown", + "id": "0fc4c0ce-9870-43f5-931b-30c8b5705122", + "metadata": {}, + "source": [ + "## 损失函数\n", + "\n", + "PreDiff训练中使用mse作为loss计算,采用了梯度裁剪,并将过程封装在了DiffusionTrainer中" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "4e8ff7ff-3641-4a1a-b412-6587c5b09562", + "metadata": {}, + "outputs": [], + "source": [ + "class DiffusionTrainer(nn.Cell):\n", + " \"\"\"\n", + " Class managing the training pipeline for diffusion models. 
Handles dataset processing,\n", + " optimizer configuration, gradient clipping, checkpoint saving, and logging.\n", + " \"\"\"\n", + " def __init__(self, main_module, dm, logger, config):\n", + " \"\"\"\n", + " Initialize trainer with model, data module, logger, and configuration.\n", + " Args:\n", + " main_module: Main diffusion model to be trained\n", + " dm: Data module providing training dataset\n", + " logger: Logging utility for training progress\n", + " config: Configuration dictionary containing hyperparameters\n", + " \"\"\"\n", + " super().__init__()\n", + " self.main_module = main_module\n", + " self.traindataset = dm.sevir_train\n", + " self.logger = logger\n", + " self.datasetprocessing = SEVIRDataset(\n", + " data_types=[\"vil\"],\n", + " layout=\"NHWT\",\n", + " rescale_method=config.get(\"rescale_method\", \"01\"),\n", + " )\n", + " self.example_save_dir = config[\"summary\"].get(\"summary_dir\", \"./summary\")\n", + " self.fs = config[\"eval\"].get(\"fs\", 20)\n", + " self.label_offset = config[\"eval\"].get(\"label_offset\", [-0.5, 0.5])\n", + " self.label_avg_int = config[\"eval\"].get(\"label_avg_int\", False)\n", + " self.current_epoch = 0\n", + " self.learn_logvar = (\n", + " config.get(\"model\", {}).get(\"diffusion\", {}).get(\"learn_logvar\", False)\n", + " )\n", + " self.logvar = main_module.logvar\n", + " self.maeloss = nn.MAELoss()\n", + " self.optim_config = config[\"optim\"]\n", + " self.clip_norm = config.get(\"clip_norm\", 2.0)\n", + " self.ckpt_dir = os.path.join(self.example_save_dir, \"ckpt\")\n", + " self.keep_ckpt_max = config[\"summary\"].get(\"keep_ckpt_max\", 100)\n", + " self.ckpt_history = []\n", + " self.grad_clip_fn = ops.clip_by_global_norm\n", + " self.optimizer = nn.Adam(params=self.main_module.main_model.trainable_params(), learning_rate=0.00001)\n", + " os.makedirs(self.ckpt_dir, exist_ok=True)\n", + "\n", + " def train(self, total_steps: int):\n", + " \"\"\"Execute complete training pipeline.\"\"\"\n", + " self.main_module.main_model.set_train(True)\n", + " self.logger.info(\"Initializing training process...\")\n", + " # optimizer, lr_scheduler = self._get_optimizer(total_steps)\n", + " loss_processor = Trainonestepforward(self.main_module)\n", + " grad_func = ms.ops.value_and_grad(loss_processor, None, self.main_module.main_model.trainable_params())\n", + " for epoch in range(self.optim_config[\"max_epochs\"]):\n", + " epoch_loss = 0.0\n", + " epoch_start = time.time()\n", + "\n", + " iterator = self.traindataset.create_dict_iterator()\n", + " assert iterator, \"dataset is empty\"\n", + " batch_idx = 0\n", + " for batch_idx, batch in enumerate(iterator):\n", + " processed_data = self.datasetprocessing.process_data(batch[\"vil\"])\n", + " loss_value, gradients = grad_func(processed_data)\n", + " clipped_grads = self.grad_clip_fn(gradients, self.clip_norm)\n", + " self.optimizer(clipped_grads)\n", + " #lr_scheduler.step()\n", + " epoch_loss += loss_value.asnumpy()\n", + " self.logger.info(\n", + " f\"epoch: {epoch} step: {batch_idx}, loss: {loss_value}\"\n", + " )\n", + " self._save_ckpt(epoch)\n", + " epoch_time = time.time() - epoch_start\n", + " self.logger.info(\n", + " f\"Epoch {epoch} completed in {epoch_time:.2f}s | \"\n", + " f\"Avg Loss: {epoch_loss/(batch_idx+1):.4f}\"\n", + " )\n", + "\n", + " def _get_optimizer(self, total_steps: int):\n", + " \"\"\"Configure optimization components\"\"\"\n", + " trainable_params = list(self.main_module.main_model.trainable_params())\n", + " if self.learn_logvar:\n", + " 
self.logger.info(\"Including log variance parameters\")\n", + " trainable_params.append(self.logvar)\n", + " optimizer = optim.AdamW(\n", + " trainable_params,\n", + " lr=self.optim_config[\"lr\"],\n", + " betas=tuple(self.optim_config[\"betas\"]),\n", + " )\n", + " warmup_steps = int(self.optim_config[\"warmup_percentage\"] * total_steps)\n", + " scheduler = self._create_lr_scheduler(optimizer, total_steps, warmup_steps)\n", + "\n", + " return optimizer, scheduler\n", + "\n", + " def _create_lr_scheduler(self, optimizer, total_steps: int, warmup_steps: int):\n", + " \"\"\"Build learning rate scheduler\"\"\"\n", + " warmup_scheduler = optim.lr_scheduler.LambdaLR(\n", + " optimizer,\n", + " lr_lambda=warmup_lambda(\n", + " warmup_steps=warmup_steps,\n", + " min_lr_ratio=self.optim_config[\"warmup_min_lr_ratio\"],\n", + " ),\n", + " )\n", + "\n", + " cosine_scheduler = optim.lr_scheduler.CosineAnnealingLR(\n", + " optimizer,\n", + " T_max=total_steps - warmup_steps,\n", + " eta_min=self.optim_config[\"min_lr_ratio\"] * self.optim_config[\"lr\"],\n", + " )\n", + "\n", + " return optim.lr_scheduler.SequentialLR(\n", + " optimizer,\n", + " schedulers=[warmup_scheduler, cosine_scheduler],\n", + " milestones=[warmup_steps],\n", + " )\n", + "\n", + " def _save_ckpt(self, epoch: int):\n", + " \"\"\"Save model ckpt with rotation policy\"\"\"\n", + " ckpt_file = f\"diffusion_epoch{epoch}.ckpt\"\n", + " ckpt_path = os.path.join(self.ckpt_dir, ckpt_file)\n", + "\n", + " save_checkpoint(self.main_module.main_model, ckpt_path)\n", + " self.ckpt_history.append(ckpt_path)\n", + "\n", + " if len(self.ckpt_history) > self.keep_ckpt_max:\n", + " removed_ckpt = self.ckpt_history.pop(0)\n", + " os.remove(removed_ckpt)\n", + " self.logger.info(f\"Removed outdated ckpt: {removed_ckpt}\")\n", + "\n", + "\n", + "class Trainonestepforward(nn.Cell):\n", + " \"\"\"A neural network cell that performs one training step forward pass for a diffusion model.\n", + " This class encapsulates the forward pass computation for training a diffusion model,\n", + " handling the input processing, latent space encoding, conditioning, and loss calculation.\n", + " Args:\n", + " model (nn.Cell): The main diffusion model containing the necessary submodules\n", + " for encoding, conditioning, and loss computation.\n", + " \"\"\"\n", + "\n", + " def __init__(self, model):\n", + " super().__init__()\n", + " self.main_module = model\n", + "\n", + " def construct(self, inputs):\n", + " \"\"\"Perform one forward training step and compute the loss.\"\"\"\n", + " x, condition = self.main_module.get_input(inputs)\n", + " x = x.transpose(0, 1, 4, 2, 3)\n", + " n, t_, c_, h, w = x.shape\n", + " x = x.reshape(n * t_, c_, h, w)\n", + " z = self.main_module.encode_first_stage(x)\n", + " _, c_z, h_z, w_z = z.shape\n", + " z = z.reshape(n, -1, c_z, h_z, w_z)\n", + " z = z.transpose(0, 1, 3, 4, 2)\n", + " t = ops.randint(0, self.main_module.num_timesteps, (n,)).long()\n", + " zc = self.main_module.cond_stage_forward(condition)\n", + " loss = self.main_module.p_losses(z, zc, t, noise=None)\n", + " return loss" + ] + }, + { + "cell_type": "markdown", + "id": "da1c23d7-fbff-4492-af19-79f30bcc0185", + "metadata": {}, + "source": [ + "## 模型训练\n", + "\n", + "在本教程中,我们使用DiffusionTrainer对模型进行训练" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "024a0222-c7c7-4cd0-9b6f-7350b92619af", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "2025-04-07 10:32:36,351 - 4106154625.py[line:46] - INFO: 
Initializing training process...\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "........." + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "2025-04-07 10:34:09,378 - 4106154625.py[line:64] - INFO: epoch: 0 step: 0, loss: 1.0008465\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "." + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "2025-04-07 10:34:16,871 - 4106154625.py[line:64] - INFO: epoch: 0 step: 1, loss: 1.0023363\n", + "2025-04-07 10:34:18,724 - 4106154625.py[line:64] - INFO: epoch: 0 step: 2, loss: 1.0009086\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "." + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "2025-04-07 10:34:20,513 - 4106154625.py[line:64] - INFO: epoch: 0 step: 3, loss: 0.99787366\n", + "2025-04-07 10:34:22,280 - 4106154625.py[line:64] - INFO: epoch: 0 step: 4, loss: 0.9979043\n", + "2025-04-07 10:34:24,072 - 4106154625.py[line:64] - INFO: epoch: 0 step: 5, loss: 0.99897844\n", + "2025-04-07 10:34:25,864 - 4106154625.py[line:64] - INFO: epoch: 0 step: 6, loss: 1.0021904\n", + "2025-04-07 10:34:27,709 - 4106154625.py[line:64] - INFO: epoch: 0 step: 7, loss: 0.9984627\n", + "2025-04-07 10:34:29,578 - 4106154625.py[line:64] - INFO: epoch: 0 step: 8, loss: 0.9952746\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "." + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "2025-04-07 10:34:31,432 - 4106154625.py[line:64] - INFO: epoch: 0 step: 9, loss: 1.0003254\n", + "2025-04-07 10:34:33,402 - 4106154625.py[line:64] - INFO: epoch: 0 step: 10, loss: 1.0020428\n", + "2025-04-07 10:34:35,218 - 4106154625.py[line:64] - INFO: epoch: 0 step: 11, loss: 0.99563503\n", + "2025-04-07 10:34:37,149 - 4106154625.py[line:64] - INFO: epoch: 0 step: 12, loss: 0.99336195\n", + "2025-04-07 10:34:38,949 - 4106154625.py[line:64] - INFO: epoch: 0 step: 13, loss: 1.0023757\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "." + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "2025-04-07 10:34:40,962 - 4106154625.py[line:64] - INFO: epoch: 0 step: 14, loss: 1.0007098\n", + "2025-04-07 10:34:43,332 - 4106154625.py[line:64] - INFO: epoch: 0 step: 15, loss: 0.99492\n", + "2025-04-07 10:34:45,177 - 4106154625.py[line:64] - INFO: epoch: 0 step: 16, loss: 0.99957407\n", + "2025-04-07 10:34:47,040 - 4106154625.py[line:64] - INFO: epoch: 0 step: 17, loss: 0.99685913\n", + "2025-04-07 10:34:48,823 - 4106154625.py[line:64] - INFO: epoch: 0 step: 18, loss: 0.9956614\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "." + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "2025-04-07 10:34:50,720 - 4106154625.py[line:64] - INFO: epoch: 0 step: 19, loss: 0.9934994\n", + "2025-04-07 10:34:52,552 - 4106154625.py[line:64] - INFO: epoch: 0 step: 20, loss: 0.99108785\n", + "2025-04-07 10:34:54,389 - 4106154625.py[line:64] - INFO: epoch: 0 step: 21, loss: 0.99182785\n", + "2025-04-07 10:34:56,159 - 4106154625.py[line:64] - INFO: epoch: 0 step: 22, loss: 0.99136275\n", + "2025-04-07 10:34:58,118 - 4106154625.py[line:64] - INFO: epoch: 0 step: 23, loss: 0.9886243\n", + "2025-04-07 10:35:00,045 - 4106154625.py[line:64] - INFO: epoch: 0 step: 24, loss: 0.9947286\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "." 
+ ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "2025-04-07 10:35:01,964 - 4106154625.py[line:64] - INFO: epoch: 0 step: 25, loss: 0.99265075\n", + "2025-04-07 10:35:03,818 - 4106154625.py[line:64] - INFO: epoch: 0 step: 26, loss: 0.98734057\n", + "2025-04-07 10:35:05,604 - 4106154625.py[line:64] - INFO: epoch: 0 step: 27, loss: 0.9867786\n", + "2025-04-07 10:35:07,383 - 4106154625.py[line:64] - INFO: epoch: 0 step: 28, loss: 0.98637533\n", + "2025-04-07 10:35:09,311 - 4106154625.py[line:64] - INFO: epoch: 0 step: 29, loss: 0.98799324\n", + "2025-04-07 10:35:11,054 - 4106154625.py[line:64] - INFO: epoch: 0 step: 30, loss: 0.9851307\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "." + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "2025-04-07 10:35:12,883 - 4106154625.py[line:64] - INFO: epoch: 0 step: 31, loss: 0.98547524\n", + "2025-04-07 10:35:14,629 - 4106154625.py[line:64] - INFO: epoch: 0 step: 32, loss: 0.9783558\n", + "2025-04-07 10:35:16,444 - 4106154625.py[line:64] - INFO: epoch: 0 step: 33, loss: 0.9851396\n", + "2025-04-07 10:35:18,122 - 4106154625.py[line:64] - INFO: epoch: 0 step: 34, loss: 0.98461366\n", + "2025-04-07 10:35:20,102 - 4106154625.py[line:64] - INFO: epoch: 0 step: 35, loss: 0.9879103\n", + "2025-04-07 10:35:22,232 - 4106154625.py[line:64] - INFO: epoch: 0 step: 36, loss: 0.9743713\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "." + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "2025-04-07 10:35:24,417 - 4106154625.py[line:64] - INFO: epoch: 0 step: 37, loss: 0.98045284\n", + "2025-04-07 10:35:26,435 - 4106154625.py[line:64] - INFO: epoch: 0 step: 38, loss: 0.97129095\n", + "2025-04-07 10:35:28,351 - 4106154625.py[line:64] - INFO: epoch: 0 step: 39, loss: 0.98204684\n", + "2025-04-07 10:35:30,122 - 4106154625.py[line:64] - INFO: epoch: 0 step: 40, loss: 0.97880834\n", + "2025-04-07 10:35:31,760 - 4106154625.py[line:64] - INFO: epoch: 0 step: 41, loss: 0.96932787\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "." + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "2025-04-07 10:35:33,513 - 4106154625.py[line:64] - INFO: epoch: 0 step: 42, loss: 0.9717276\n", + "2025-04-07 10:35:35,276 - 4106154625.py[line:64] - INFO: epoch: 0 step: 43, loss: 0.9716038\n", + "2025-04-07 10:35:37,238 - 4106154625.py[line:64] - INFO: epoch: 0 step: 44, loss: 0.9686392\n", + "2025-04-07 10:35:39,268 - 4106154625.py[line:64] - INFO: epoch: 0 step: 45, loss: 0.99201906\n", + "2025-04-07 10:35:41,141 - 4106154625.py[line:64] - INFO: epoch: 0 step: 46, loss: 0.977281\n", + "2025-04-07 10:35:43,166 - 4106154625.py[line:64] - INFO: epoch: 0 step: 47, loss: 0.96613944\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "." + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "2025-04-07 10:35:45,249 - 4106154625.py[line:64] - INFO: epoch: 0 step: 48, loss: 0.9612762\n", + "2025-04-07 10:35:47,142 - 4106154625.py[line:64] - INFO: epoch: 0 step: 49, loss: 0.9577536\n", + "2025-04-07 10:35:49,114 - 4106154625.py[line:64] - INFO: epoch: 0 step: 50, loss: 0.95175207\n", + "2025-04-07 10:35:51,080 - 4106154625.py[line:64] - INFO: epoch: 0 step: 51, loss: 0.95729643\n", + "2025-04-07 10:35:53,116 - 4106154625.py[line:64] - INFO: epoch: 0 step: 52, loss: 0.960687\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "." 
+ ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "2025-04-07 10:35:55,202 - 4106154625.py[line:64] - INFO: epoch: 0 step: 53, loss: 0.9575224\n", + "2025-04-07 10:35:57,168 - 4106154625.py[line:64] - INFO: epoch: 0 step: 54, loss: 0.9500365\n", + "2025-04-07 10:35:59,015 - 4106154625.py[line:64] - INFO: epoch: 0 step: 55, loss: 0.94735086\n", + "2025-04-07 10:36:01,016 - 4106154625.py[line:64] - INFO: epoch: 0 step: 56, loss: 0.97874105\n", + "2025-04-07 10:36:02,904 - 4106154625.py[line:64] - INFO: epoch: 0 step: 57, loss: 0.9451903\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "." + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "2025-04-07 10:36:04,717 - 4106154625.py[line:64] - INFO: epoch: 0 step: 58, loss: 0.94447565\n", + "2025-04-07 10:36:06,499 - 4106154625.py[line:64] - INFO: epoch: 0 step: 59, loss: 0.94874763\n", + "2025-04-07 10:36:08,260 - 4106154625.py[line:64] - INFO: epoch: 0 step: 60, loss: 0.9672854\n", + "2025-04-07 10:36:10,146 - 4106154625.py[line:64] - INFO: epoch: 0 step: 61, loss: 0.9565505\n", + "2025-04-07 10:36:12,112 - 4106154625.py[line:64] - INFO: epoch: 0 step: 62, loss: 0.9480209\n", + "2025-04-07 10:36:13,989 - 4106154625.py[line:64] - INFO: epoch: 0 step: 63, loss: 0.94844496\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "." + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "2025-04-07 10:36:15,759 - 4106154625.py[line:64] - INFO: epoch: 0 step: 64, loss: 0.94463414\n", + "2025-04-07 10:36:17,409 - 4106154625.py[line:64] - INFO: epoch: 0 step: 65, loss: 0.9484377\n", + "2025-04-07 10:36:19,103 - 4106154625.py[line:64] - INFO: epoch: 0 step: 66, loss: 0.93955624\n", + "2025-04-07 10:36:21,005 - 4106154625.py[line:64] - INFO: epoch: 0 step: 67, loss: 0.9357619\n", + "2025-04-07 10:36:22,738 - 4106154625.py[line:64] - INFO: epoch: 0 step: 68, loss: 0.9534744\n", + "2025-04-07 10:36:24,626 - 4106154625.py[line:64] - INFO: epoch: 0 step: 69, loss: 0.970679\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "." + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "2025-04-07 10:36:26,527 - 4106154625.py[line:64] - INFO: epoch: 0 step: 70, loss: 0.9313204\n", + "2025-04-07 10:36:28,335 - 4106154625.py[line:64] - INFO: epoch: 0 step: 71, loss: 0.927449\n", + "2025-04-07 10:36:30,082 - 4106154625.py[line:64] - INFO: epoch: 0 step: 72, loss: 0.9536683\n", + "2025-04-07 10:36:31,761 - 4106154625.py[line:64] - INFO: epoch: 0 step: 73, loss: 0.92975646\n", + "2025-04-07 10:36:33,780 - 4106154625.py[line:64] - INFO: epoch: 0 step: 74, loss: 0.9387269\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "." + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "2025-04-07 10:36:35,917 - 4106154625.py[line:64] - INFO: epoch: 0 step: 75, loss: 0.9491191\n", + "2025-04-07 10:36:37,922 - 4106154625.py[line:64] - INFO: epoch: 0 step: 76, loss: 0.9263407\n", + "2025-04-07 10:36:39,572 - 4106154625.py[line:64] - INFO: epoch: 0 step: 77, loss: 0.95135903\n", + "2025-04-07 10:36:41,209 - 4106154625.py[line:64] - INFO: epoch: 0 step: 78, loss: 0.92555064\n", + "2025-04-07 10:36:42,827 - 4106154625.py[line:64] - INFO: epoch: 0 step: 79, loss: 0.93047976\n", + "2025-04-07 10:36:44,649 - 4106154625.py[line:64] - INFO: epoch: 0 step: 80, loss: 0.9445814\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "." 
+ ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "2025-04-07 10:36:46,501 - 4106154625.py[line:64] - INFO: epoch: 0 step: 81, loss: 0.92167306\n", + "2025-04-07 10:36:48,449 - 4106154625.py[line:64] - INFO: epoch: 0 step: 82, loss: 0.9199027\n", + "2025-04-07 10:36:50,603 - 4106154625.py[line:64] - INFO: epoch: 0 step: 83, loss: 0.95979875\n", + "2025-04-07 10:36:52,662 - 4106154625.py[line:64] - INFO: epoch: 0 step: 84, loss: 0.94403404\n", + "2025-04-07 10:36:54,314 - 4106154625.py[line:64] - INFO: epoch: 0 step: 85, loss: 0.91954345\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "." + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "2025-04-07 10:36:55,996 - 4106154625.py[line:64] - INFO: epoch: 0 step: 86, loss: 0.92873365\n", + "2025-04-07 10:36:57,701 - 4106154625.py[line:64] - INFO: epoch: 0 step: 87, loss: 0.91166925\n", + "2025-04-07 10:36:59,362 - 4106154625.py[line:64] - INFO: epoch: 0 step: 88, loss: 0.92743254\n", + "2025-04-07 10:37:01,139 - 4106154625.py[line:64] - INFO: epoch: 0 step: 89, loss: 0.9097767\n", + "2025-04-07 10:37:03,120 - 4106154625.py[line:64] - INFO: epoch: 0 step: 90, loss: 0.918455\n", + "2025-04-07 10:37:05,260 - 4106154625.py[line:64] - INFO: epoch: 0 step: 91, loss: 0.9123219\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "." + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "2025-04-07 10:37:06,972 - 4106154625.py[line:64] - INFO: epoch: 0 step: 92, loss: 0.9185343\n", + "2025-04-07 10:37:08,881 - 4106154625.py[line:64] - INFO: epoch: 0 step: 93, loss: 0.9153005\n", + "2025-04-07 10:37:10,973 - 4106154625.py[line:64] - INFO: epoch: 0 step: 94, loss: 0.90332276\n", + "2025-04-07 10:37:13,070 - 4106154625.py[line:64] - INFO: epoch: 0 step: 95, loss: 0.90544885\n", + "2025-04-07 10:37:14,777 - 4106154625.py[line:64] - INFO: epoch: 0 step: 96, loss: 0.92892224\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "." + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "2025-04-07 10:37:16,919 - 4106154625.py[line:64] - INFO: epoch: 0 step: 97, loss: 0.92682004\n", + "2025-04-07 10:37:18,923 - 4106154625.py[line:64] - INFO: epoch: 0 step: 98, loss: 0.9004317\n", + "2025-04-07 10:37:20,940 - 4106154625.py[line:64] - INFO: epoch: 0 step: 99, loss: 0.908974\n", + "2025-04-07 10:37:22,739 - 4106154625.py[line:64] - INFO: epoch: 0 step: 100, loss: 0.8956867\n", + "2025-04-07 10:37:24,509 - 4106154625.py[line:64] - INFO: epoch: 0 step: 101, loss: 0.8987319\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "." + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "2025-04-07 10:37:26,159 - 4106154625.py[line:64] - INFO: epoch: 0 step: 102, loss: 0.9083508\n", + "2025-04-07 10:37:27,783 - 4106154625.py[line:64] - INFO: epoch: 0 step: 103, loss: 0.89505464\n", + "2025-04-07 10:37:29,432 - 4106154625.py[line:64] - INFO: epoch: 0 step: 104, loss: 0.9006442\n", + "2025-04-07 10:37:31,031 - 4106154625.py[line:64] - INFO: epoch: 0 step: 105, loss: 0.8925739\n", + "2025-04-07 10:37:32,688 - 4106154625.py[line:64] - INFO: epoch: 0 step: 106, loss: 0.8919925\n", + "2025-04-07 10:37:34,278 - 4106154625.py[line:64] - INFO: epoch: 0 step: 107, loss: 0.8901893\n", + "2025-04-07 10:37:35,874 - 4106154625.py[line:64] - INFO: epoch: 0 step: 108, loss: 0.8947307\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "." 
+ ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "2025-04-07 10:37:37,562 - 4106154625.py[line:64] - INFO: epoch: 0 step: 109, loss: 0.89940923\n", + "2025-04-07 10:37:39,124 - 4106154625.py[line:64] - INFO: epoch: 0 step: 110, loss: 0.88965017\n", + "2025-04-07 10:37:40,773 - 4106154625.py[line:64] - INFO: epoch: 0 step: 111, loss: 0.8835504\n", + "2025-04-07 10:37:42,345 - 4106154625.py[line:64] - INFO: epoch: 0 step: 112, loss: 0.8785033\n", + "2025-04-07 10:37:43,921 - 4106154625.py[line:64] - INFO: epoch: 0 step: 113, loss: 0.8814548\n", + "2025-04-07 10:37:45,600 - 4106154625.py[line:64] - INFO: epoch: 0 step: 114, loss: 0.8877945\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "." + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "2025-04-07 10:37:47,338 - 4106154625.py[line:64] - INFO: epoch: 0 step: 115, loss: 0.88197625\n", + "2025-04-07 10:37:48,996 - 4106154625.py[line:64] - INFO: epoch: 0 step: 116, loss: 0.8941308\n", + "2025-04-07 10:37:50,679 - 4106154625.py[line:64] - INFO: epoch: 0 step: 117, loss: 0.88495713\n", + "2025-04-07 10:37:52,603 - 4106154625.py[line:64] - INFO: epoch: 0 step: 118, loss: 0.90219486\n", + "2025-04-07 10:37:54,497 - 4106154625.py[line:64] - INFO: epoch: 0 step: 119, loss: 0.89262724\n", + "2025-04-07 10:37:56,103 - 4106154625.py[line:64] - INFO: epoch: 0 step: 120, loss: 0.8879415\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "." + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "2025-04-07 10:37:57,735 - 4106154625.py[line:64] - INFO: epoch: 0 step: 121, loss: 0.878676\n", + "2025-04-07 10:37:59,364 - 4106154625.py[line:64] - INFO: epoch: 0 step: 122, loss: 0.8715365\n", + "2025-04-07 10:38:00,946 - 4106154625.py[line:64] - INFO: epoch: 0 step: 123, loss: 0.8677654\n", + "2025-04-07 10:38:02,558 - 4106154625.py[line:64] - INFO: epoch: 0 step: 124, loss: 0.8684499\n", + "2025-04-07 10:38:04,199 - 4106154625.py[line:64] - INFO: epoch: 0 step: 125, loss: 0.8848672\n", + "2025-04-07 10:38:05,816 - 4106154625.py[line:64] - INFO: epoch: 0 step: 126, loss: 0.8611082\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "." + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "2025-04-07 10:38:07,435 - 4106154625.py[line:64] - INFO: epoch: 0 step: 127, loss: 0.87677616\n", + "2025-04-07 10:38:09,051 - 4106154625.py[line:64] - INFO: epoch: 0 step: 128, loss: 0.8892087\n", + "2025-04-07 10:38:10,675 - 4106154625.py[line:64] - INFO: epoch: 0 step: 129, loss: 0.87242335\n", + "2025-04-07 10:38:12,362 - 4106154625.py[line:64] - INFO: epoch: 0 step: 130, loss: 0.86540776\n", + "2025-04-07 10:38:13,976 - 4106154625.py[line:64] - INFO: epoch: 0 step: 131, loss: 0.9510796\n", + "2025-04-07 10:38:15,605 - 4106154625.py[line:64] - INFO: epoch: 0 step: 132, loss: 0.8619976\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "." 
+    ]
+   },
+   {
+    "name": "stderr",
+    "output_type": "stream",
+    "text": [
+     "......"
+    ]
+   },
+   {
+    "name": "stdout",
+    "output_type": "stream",
+    "text": [
+     "."
+ ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "2025-04-07 10:41:20,764 - 4106154625.py[line:64] - INFO: epoch: 0 step: 236, loss: 0.7427261\n", + "2025-04-07 10:41:22,569 - 4106154625.py[line:64] - INFO: epoch: 0 step: 237, loss: 0.8036304\n", + "2025-04-07 10:41:24,487 - 4106154625.py[line:64] - INFO: epoch: 0 step: 238, loss: 0.76217574\n", + "2025-04-07 10:41:26,373 - 4106154625.py[line:64] - INFO: epoch: 0 step: 239, loss: 0.7397079\n", + "2025-04-07 10:41:28,139 - 4106154625.py[line:64] - INFO: epoch: 0 step: 240, loss: 0.8942822\n", + "2025-04-07 10:41:30,037 - 4106154625.py[line:64] - INFO: epoch: 0 step: 241, loss: 0.74506545\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "." + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "2025-04-07 10:41:32,130 - 4106154625.py[line:64] - INFO: epoch: 0 step: 242, loss: 0.7901791\n", + "2025-04-07 10:41:34,077 - 4106154625.py[line:64] - INFO: epoch: 0 step: 243, loss: 0.74124205\n", + "2025-04-07 10:41:35,979 - 4106154625.py[line:64] - INFO: epoch: 0 step: 244, loss: 0.7894727\n", + "2025-04-07 10:41:37,959 - 4106154625.py[line:64] - INFO: epoch: 0 step: 245, loss: 0.83756655\n", + "2025-04-07 10:41:39,831 - 4106154625.py[line:64] - INFO: epoch: 0 step: 246, loss: 0.7398231\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "." + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "2025-04-07 10:41:41,763 - 4106154625.py[line:64] - INFO: epoch: 0 step: 247, loss: 0.76385504\n", + "2025-04-07 10:41:43,700 - 4106154625.py[line:64] - INFO: epoch: 0 step: 248, loss: 0.7347469\n", + "2025-04-07 10:41:45,518 - 4106154625.py[line:64] - INFO: epoch: 0 step: 249, loss: 0.8313259\n", + "2025-04-07 10:41:47,373 - 4106154625.py[line:64] - INFO: epoch: 0 step: 250, loss: 0.8136975\n", + "2025-04-07 10:41:49,420 - 4106154625.py[line:64] - INFO: epoch: 0 step: 251, loss: 0.7310439\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "." + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "......" + ] + }, + + { + "name": "stderr", + "output_type": "stream", + "text": [ + "2025-04-07 13:39:55,859 - 4106154625.py[line:64] - INFO: epoch: 4 step: 1247, loss: 0.021378823\n", + "2025-04-07 13:39:57,754 - 4106154625.py[line:64] - INFO: epoch: 4 step: 1248, loss: 0.01565772\n", + "2025-04-07 13:39:59,606 - 4106154625.py[line:64] - INFO: epoch: 4 step: 1249, loss: 0.012067624\n", + "2025-04-07 13:40:01,396 - 4106154625.py[line:64] - INFO: epoch: 4 step: 1250, loss: 0.017700804\n", + "2025-04-07 13:40:03,181 - 4106154625.py[line:64] - INFO: epoch: 4 step: 1251, loss: 0.06254268\n", + "2025-04-07 13:40:04,945 - 4106154625.py[line:64] - INFO: epoch: 4 step: 1252, loss: 0.013293369\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "." 
+ ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "2025-04-07 13:40:06,770 - 4106154625.py[line:64] - INFO: epoch: 4 step: 1253, loss: 0.026906993\n", + "2025-04-07 13:40:08,644 - 4106154625.py[line:64] - INFO: epoch: 4 step: 1254, loss: 0.18210539\n", + "2025-04-07 13:40:10,593 - 4106154625.py[line:64] - INFO: epoch: 4 step: 1255, loss: 0.024170894\n", + "2025-04-07 13:40:12,430 - 4106154625.py[line:69] - INFO: Epoch 4 completed in 2274.61s | Avg Loss: 0.0517\n" + ] + } + ], + "source": [ + "trainer = DiffusionTrainer(\n", + " main_module=main_module, dm=dm, logger=logger, config=config\n", + ")\n", + "trainer.train(total_steps=total_num_steps)" + ] + }, + { + "cell_type": "markdown", + "id": "5be19a94-d51c-487f-bfde-8a9d9dc57d79", + "metadata": {}, + "source": [ + "## 模型评估与可视化\n", + "\n", + "完成训练后,我们使用第5个ckpt进行推理。下述展示了预测值与实际值之间的误差和各项指标以及结果可视化。" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "id": "e0654f48-d02d-4258-b86c-61da318cf10e", + "metadata": {}, + "outputs": [], + "source": [ + "def get_alignment_kwargs_avg_x(target_seq):\n", + " \"\"\"Generate alignment parameters for guided sampling\"\"\"\n", + " batch_size = target_seq.shape[0]\n", + " avg_intensity = mint.mean(target_seq.view(batch_size, -1), dim=1, keepdim=True)\n", + " return {\"avg_x_gt\": avg_intensity * 2.0}\n", + "\n", + "\n", + "class DiffusionInferrence(nn.Cell):\n", + " \"\"\"\n", + " Class managing model inference and evaluation processes. Handles loading checkpoints,\n", + " generating predictions, calculating evaluation metrics, and saving visualization results.\n", + " \"\"\"\n", + " def __init__(self, main_module, dm, logger, config):\n", + " \"\"\"\n", + " Initialize inference manager with model, data module, logger, and configuration.\n", + " Args:\n", + " main_module: Main diffusion model for inference\n", + " dm: Data module providing test dataset\n", + " logger: Logging utility for evaluation progress\n", + " config: Configuration dictionary containing evaluation parameters\n", + " \"\"\"\n", + " super().__init__()\n", + " self.num_samples = config[\"eval\"].get(\"num_samples_per_context\", 1)\n", + " self.eval_example_only = config[\"eval\"].get(\"eval_example_only\", True)\n", + " self.alignment_type = (\n", + " config.get(\"model\", {}).get(\"align\", {}).get(\"alignment_type\", \"avg_x\")\n", + " )\n", + " self.use_alignment = self.alignment_type is not None\n", + " self.eval_aligned = config[\"eval\"].get(\"eval_aligned\", True)\n", + " self.eval_unaligned = config[\"eval\"].get(\"eval_unaligned\", True)\n", + " self.num_samples_per_context = config[\"eval\"].get(\"num_samples_per_context\", 1)\n", + " self.logging_prefix = config[\"logging\"].get(\"logging_prefix\", \"PreDiff\")\n", + " self.test_example_data_idx_list = [48]\n", + " self.main_module = main_module\n", + " self.testdataset = dm.sevir_test\n", + " self.logger = logger\n", + " self.datasetprocessing = SEVIRDataset(\n", + " data_types=[\"vil\"],\n", + " layout=\"NHWT\",\n", + " rescale_method=config.get(\"rescale_method\", \"01\"),\n", + " )\n", + " self.example_save_dir = config[\"summary\"].get(\"summary_dir\", \"./summary\")\n", + "\n", + " self.fs = config[\"eval\"].get(\"fs\", 20)\n", + " self.label_offset = config[\"eval\"].get(\"label_offset\", [-0.5, 0.5])\n", + " self.label_avg_int = config[\"eval\"].get(\"label_avg_int\", False)\n", + "\n", + " self.current_epoch = 0\n", + "\n", + " self.learn_logvar = (\n", + " config.get(\"model\", {}).get(\"diffusion\", {}).get(\"learn_logvar\", 
False)\n", + " )\n", + " self.logvar = main_module.logvar\n", + " self.maeloss = nn.MAELoss()\n", + " self.test_metrics = {\n", + " \"step\": 0,\n", + " \"mse\": 0.0,\n", + " \"mae\": 0.0,\n", + " \"ssim\": 0.0,\n", + " \"mse_kc\": 0.0,\n", + " \"mae_kc\": 0.0,\n", + " }\n", + "\n", + " def test(self):\n", + " \"\"\"Execute complete evaluation pipeline.\"\"\"\n", + " self.logger.info(\"============== Start Test ==============\")\n", + " self.start_time = time.time()\n", + " for batch_idx, item in enumerate(self.testdataset.create_dict_iterator()):\n", + " self.test_metrics = self._test_onestep(item, batch_idx, self.test_metrics)\n", + "\n", + " self._finalize_test(self.test_metrics)\n", + "\n", + " def _test_onestep(self, item, batch_idx, metrics):\n", + " \"\"\"Process one test batch and update evaluation metrics.\"\"\"\n", + " data_idx = int(batch_idx * 2)\n", + " if not self._should_test_onestep(data_idx):\n", + " return metrics\n", + " data = item.get(\"vil\")\n", + " data = self.datasetprocessing.process_data(data)\n", + " target_seq, cond, context_seq = self._get_model_inputs(data)\n", + " aligned_preds, unaligned_preds = self._generate_predictions(\n", + " cond, target_seq\n", + " )\n", + " metrics = self._update_metrics(\n", + " aligned_preds, unaligned_preds, target_seq, metrics\n", + " )\n", + " self._plt_pred(\n", + " data_idx,\n", + " context_seq,\n", + " target_seq,\n", + " aligned_preds,\n", + " unaligned_preds,\n", + " metrics[\"step\"],\n", + " )\n", + "\n", + " metrics[\"step\"] += 1\n", + " return metrics\n", + "\n", + " def _should_test_onestep(self, data_idx):\n", + " \"\"\"Determine if evaluation should be performed on current data index.\"\"\"\n", + " return (not self.eval_example_only) or (\n", + " data_idx in self.test_example_data_idx_list\n", + " )\n", + "\n", + " def _get_model_inputs(self, data):\n", + " \"\"\"Extract and prepare model inputs from raw data.\"\"\"\n", + " target_seq, cond, context_seq = self.main_module.get_input(\n", + " data, return_verbose=True\n", + " )\n", + " return target_seq, cond, context_seq\n", + "\n", + " def _generate_predictions(self, cond, target_seq):\n", + " \"\"\"Generate both aligned and unaligned predictions from the model.\"\"\"\n", + " aligned_preds = []\n", + " unaligned_preds = []\n", + "\n", + " for _ in range(self.num_samples_per_context):\n", + " if self.use_alignment and self.eval_aligned:\n", + " aligned_pred = self._sample_with_alignment(\n", + " cond, target_seq\n", + " )\n", + " aligned_preds.append(aligned_pred)\n", + "\n", + " if self.eval_unaligned:\n", + " unaligned_pred = self._sample_without_alignment(cond)\n", + " unaligned_preds.append(unaligned_pred)\n", + "\n", + " return aligned_preds, unaligned_preds\n", + "\n", + " def _sample_with_alignment(self, cond, target_seq):\n", + " \"\"\"Generate predictions using alignment mechanism.\"\"\"\n", + " alignment_kwargs = get_alignment_kwargs_avg_x(target_seq)\n", + " pred_seq = self.main_module.sample(\n", + " cond=cond,\n", + " batch_size=cond[\"y\"].shape[0],\n", + " return_intermediates=False,\n", + " use_alignment=True,\n", + " alignment_kwargs=alignment_kwargs,\n", + " verbose=False,\n", + " )\n", + " if pred_seq.dtype != ms.float32:\n", + " pred_seq = pred_seq.float()\n", + " return pred_seq\n", + "\n", + " def _sample_without_alignment(self, cond):\n", + " \"\"\"Generate predictions without alignment.\"\"\"\n", + " pred_seq = self.main_module.sample(\n", + " cond=cond,\n", + " batch_size=cond[\"y\"].shape[0],\n", + " return_intermediates=False,\n", + " 
verbose=False,\n", + " )\n", + " if pred_seq.dtype != ms.float32:\n", + " pred_seq = pred_seq.float()\n", + " return pred_seq\n", + "\n", + " def _update_metrics(self, aligned_preds, unaligned_preds, target_seq, metrics):\n", + " \"\"\"Update evaluation metrics with new predictions.\"\"\"\n", + " for pred in aligned_preds:\n", + " metrics[\"mse_kc\"] += ops.mse_loss(pred, target_seq)\n", + " metrics[\"mae_kc\"] += self.maeloss(pred, target_seq)\n", + " self.main_module.test_aligned_score.update(pred, target_seq)\n", + "\n", + " for pred in unaligned_preds:\n", + " metrics[\"mse\"] += ops.mse_loss(pred, target_seq)\n", + " metrics[\"mae\"] += self.maeloss(pred, target_seq)\n", + " self.main_module.test_score.update(pred, target_seq)\n", + "\n", + " pred_bchw = self._convert_to_bchw(pred)\n", + " target_bchw = self._convert_to_bchw(target_seq)\n", + " metrics[\"ssim\"] += self.main_module.test_ssim(pred_bchw, target_bchw)[0]\n", + "\n", + " return metrics\n", + "\n", + " def _convert_to_bchw(self, tensor):\n", + " \"\"\"Convert tensor to batch-channel-height-width format for metrics.\"\"\"\n", + " return rearrange(tensor.asnumpy(), \"b t h w c -> (b t) c h w\")\n", + "\n", + " def _plt_pred(\n", + " self, data_idx, context_seq, target_seq, aligned_preds, unaligned_preds, step\n", + " ):\n", + " \"\"\"Generate and save visualization of predictions.\"\"\"\n", + " pred_sequences = [pred[0].asnumpy() for pred in aligned_preds + unaligned_preds]\n", + " pred_labels = [\n", + " f\"{self.logging_prefix}_aligned_pred_{i}\" for i in range(len(aligned_preds))\n", + " ] + [f\"{self.logging_prefix}_pred_{i}\" for i in range(len(unaligned_preds))]\n", + "\n", + " self.save_vis_step_end(\n", + " data_idx=data_idx,\n", + " context_seq=context_seq[0].asnumpy(),\n", + " target_seq=target_seq[0].asnumpy(),\n", + " pred_seq=pred_sequences,\n", + " pred_label=pred_labels,\n", + " mode=\"test\",\n", + " suffix=f\"_step_{step}\",\n", + " )\n", + "\n", + " def _finalize_test(self, metrics):\n", + " \"\"\"Complete test process and log final metrics.\"\"\"\n", + " total_time = (time.time() - self.start_time) * 1000\n", + " self.logger.info(f\"test cost: {total_time:.2f} ms\")\n", + " self._compute_total_metrics(metrics)\n", + " self.logger.info(\"============== Test Completed ==============\")\n", + "\n", + " def _compute_total_metrics(self, metrics):\n", + " \"\"\"log_metrics\"\"\"\n", + " step_count = max(metrics[\"step\"], 1)\n", + " if self.eval_unaligned:\n", + " self.logger.info(f\"MSE: {metrics['mse'] / step_count}\")\n", + " self.logger.info(f\"MAE: {metrics['mae'] / step_count}\")\n", + " self.logger.info(f\"SSIM: {metrics['ssim'] / step_count}\")\n", + " test_score = self.main_module.test_score.eval()\n", + " self.logger.info(\"SCORE:\\n%s\", json.dumps(test_score, indent=4))\n", + " if self.use_alignment:\n", + " self.logger.info(f\"KC_MSE: {metrics['mse_kc'] / step_count}\")\n", + " self.logger.info(f\"KC_MAE: {metrics['mae_kc'] / step_count}\")\n", + " aligned_score = self.main_module.test_aligned_score.eval()\n", + " self.logger.info(\"KC_SCORE:\\n%s\", json.dumps(aligned_score, indent=4))\n", + "\n", + " def save_vis_step_end(\n", + " self,\n", + " data_idx: int,\n", + " context_seq: np.ndarray,\n", + " target_seq: np.ndarray,\n", + " pred_seq: Union[np.ndarray, Sequence[np.ndarray]],\n", + " pred_label: Union[str, Sequence[str]] = None,\n", + " mode: str = \"train\",\n", + " prefix: str = \"\",\n", + " suffix: str = \"\",\n", + " ):\n", + " \"\"\"Save visualization of predictions with context and 
target.\"\"\"\n", + " example_data_idx_list = self.test_example_data_idx_list\n", + " if isinstance(pred_seq, Sequence):\n", + " seq_list = [context_seq, target_seq] + list(pred_seq)\n", + " label_list = [\"context\", \"target\"] + pred_label\n", + " else:\n", + " seq_list = [context_seq, target_seq, pred_seq]\n", + " label_list = [\"context\", \"target\", pred_label]\n", + " if data_idx in example_data_idx_list:\n", + " png_save_name = f\"{prefix}{mode}_data_{data_idx}{suffix}.png\"\n", + " vis_sevir_seq(\n", + " save_path=os.path.join(self.example_save_dir, png_save_name),\n", + " seq=seq_list,\n", + " label=label_list,\n", + " interval_real_time=10,\n", + " plot_stride=1,\n", + " fs=self.fs,\n", + " label_offset=self.label_offset,\n", + " label_avg_int=self.label_avg_int,\n", + " )" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "id": "bf830d99-eec2-473e-8ffa-900fc2314b22", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "2025-04-07 14:04:16,558 - 2610859736.py[line:66] - INFO: ============== Start Test ==============\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[]\n", + ".." + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "2025-04-07 14:10:31,931 - 2610859736.py[line:201] - INFO: test cost: 375371.60 ms\n", + "2025-04-07 14:10:31,937 - 2610859736.py[line:215] - INFO: KC_MSE: 0.0036273836\n", + "2025-04-07 14:10:31,939 - 2610859736.py[line:216] - INFO: KC_MAE: 0.017427118\n", + "2025-04-07 14:10:31,955 - 2610859736.py[line:218] - INFO: KC_SCORE:\n", + "{\n", + " \"16\": {\n", + " \"csi\": 0.2715393900871277,\n", + " \"pod\": 0.5063194632530212,\n", + " \"sucr\": 0.369321346282959,\n", + " \"bias\": 3.9119162559509277\n", + " },\n", + " \"74\": {\n", + " \"csi\": 0.15696434676647186,\n", + " \"pod\": 0.17386901378631592,\n", + " \"sucr\": 0.6175059080123901,\n", + " \"bias\": 0.16501028835773468\n", + " }\n", + "}\n", + "2025-04-07 14:10:31,956 - 2610859736.py[line:203] - INFO: ============== Test Completed ==============\n" + ] + } + ], + "source": [ + "main_module.main_model.set_train(False)\n", + "params = ms.load_checkpoint(\"/home/lry/202542测试/PreDiff/summary/prediff/single_device0/ckpt/diffusion_epoch4.ckpt\")\n", + "a, b = ms.load_param_into_net(main_module.main_model, params)\n", + "print(b)\n", + "tester = DiffusionInferrence(\n", + " main_module=main_module, dm=dm, logger=logger, config=config\n", + " )\n", + "tester.test()\n" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.11" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/MindEarth/applications/nowcasting/PreDiff/src/__init__.py b/MindEarth/applications/nowcasting/PreDiff/src/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..7139ede4800b666a02906a5d586ba6ab4f3dcb2f --- /dev/null +++ b/MindEarth/applications/nowcasting/PreDiff/src/__init__.py @@ -0,0 +1,30 @@ +# Copyright 2023 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this filepio[] except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""init""" +from src.diffusion import PreDiffModule, DiffusionTrainer, DiffusionInferrence +from src.utils import ( + prepare_output_directory, + configure_logging_system, + prepare_dataset, + init_model, +) +__all__ = ['prepare_output_directory', + 'configure_logging_system', + 'prepare_dataset', + 'init_model', + 'PreDiffModule', + 'DiffusionTrainer', + 'DiffusionInferrence', + ] diff --git a/MindEarth/applications/nowcasting/PreDiff/src/diffusion/__init__.py b/MindEarth/applications/nowcasting/PreDiff/src/diffusion/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..ccf0ef12773899f0e698c9bc103adbc2f3639ea7 --- /dev/null +++ b/MindEarth/applications/nowcasting/PreDiff/src/diffusion/__init__.py @@ -0,0 +1,33 @@ +# Copyright 2023 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this filepio[] except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""init""" +from .time_embed import TimeEmbedLayer, TimeEmbedResBlock +from .cuboid_transformer import PatchMerging3D, PosEmbed, StackCuboidSelfAttentionBlock +from .latent_diffusion import PreDiffModule +from .forecast import DiffusionInferrence +from .solver import DiffusionTrainer +from .cuboid_transformer_unet import self_axial + +__all__ = [ + "TimeEmbedLayer", + "TimeEmbedResBlock", + "PatchMerging3D", + "PosEmbed", + "StackCuboidSelfAttentionBlock", + "PreDiffModule", + "DiffusionInferrence", + "DiffusionTrainer", + "self_axial" +] diff --git a/MindEarth/applications/nowcasting/PreDiff/src/diffusion/cuboid_transformer.py b/MindEarth/applications/nowcasting/PreDiff/src/diffusion/cuboid_transformer.py new file mode 100644 index 0000000000000000000000000000000000000000..b51fb72cd72e32e2b799ba51156782ea9359eca9 --- /dev/null +++ b/MindEarth/applications/nowcasting/PreDiff/src/diffusion/cuboid_transformer.py @@ -0,0 +1,1062 @@ +# Copyright 2023 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""The Substructure of cuboid_transformer_unet""" +from functools import lru_cache +from collections import OrderedDict + +import mindspore as ms +from mindspore import nn, ops, Parameter +from mindspore.common.initializer import initializer, TruncatedNormal + +from src.utils import ( + get_activation, + get_norm_layer, + generalize_padding, + generalize_unpadding, + apply_initialization, +) + + +class PosEmbed(nn.Cell): + """ + Spatiotemporal positional embedding layer combining temporal, height, and width embeddings. + """ + def __init__(self, embed_dim, max_t, max_h, max_w): + """ + Initialize positional embedding with separate temporal/spatial components. + Args: + embed_dim (int): Dimensionality of the embedding vectors. + maxT (int): Maximum temporal length (number of time steps). + maxH (int): Maximum height dimension size. + maxW (int): Maximum width dimension size. + """ + super().__init__() + self.embed_dim = embed_dim + # spatiotemporal learned positional embedding + self.t_embed = nn.Embedding(vocab_size=max_t, embedding_size=embed_dim) + self.h_embed = nn.Embedding(vocab_size=max_h, embedding_size=embed_dim) + self.w_embed = nn.Embedding(vocab_size=max_w, embedding_size=embed_dim) + self.reset_parameters() + + def reset_parameters(self): + for cell in self.cells(): + apply_initialization(cell, embed_mode="0") + + def construct(self, x): + """Forward pass of positional embedding. + + Args: + x (Tensor): Input tensor of shape (B, T, H, W, C) + + Returns: + Tensor: Output tensor with added positional embeddings + """ + + _, t, h, w, _ = x.shape + + t_idx = ops.arange(t) + h_idx = ops.arange(h) + w_idx = ops.arange(w) + return ( + x + + self.t_embed(t_idx).reshape(t, 1, 1, self.embed_dim) + + self.h_embed(h_idx).reshape(1, h, 1, self.embed_dim) + + self.w_embed(w_idx).reshape(1, 1, w, self.embed_dim) + ) + + +class PositionwiseFFN(nn.Cell): + """The Position-wise Feed-Forward Network layer used in Transformer architectures. + + This implements a two-layer MLP with optional gating mechanism and normalization. 
+ The processing order depends on the pre_norm parameter: + + If pre_norm is True: + norm(data) -> fc1 -> act -> act_dropout -> fc2 -> dropout -> residual_add(+data) + Else: + data -> fc1 -> act -> act_dropout -> fc2 -> dropout -> norm(residual_add(+data)) + + When gated projection is enabled, uses: + fc1_1 * act(fc1_2(data)) for the first projection + """ + + def __init__( + self, + units: int = 512, + hidden_size: int = 2048, + activation_dropout: float = 0.0, + dropout: float = 0.1, + gated_proj: bool = False, + activation="relu", + normalization: str = "layer_norm", + layer_norm_eps: float = 1e-5, + pre_norm: bool = False, + linear_init_mode="0", + ffn2_linear_init_mode="2", + norm_init_mode="0", + ): + super().__init__() + self.linear_init_mode = linear_init_mode + self.ffn2_linear_init_mode = ffn2_linear_init_mode + self.norm_init_mode = norm_init_mode + + self._pre_norm = pre_norm + self._gated_proj = gated_proj + self._kwargs = OrderedDict( + [ + ("units", units), + ("hidden_size", hidden_size), + ("activation_dropout", activation_dropout), + ("activation", activation), + ("dropout", dropout), + ("normalization", normalization), + ("layer_norm_eps", layer_norm_eps), + ("gated_proj", gated_proj), + ("pre_norm", pre_norm), + ] + ) + self.dropout_layer = nn.Dropout(p=dropout) + self.activation_dropout_layer = nn.Dropout(p=activation_dropout) + self.ffn_1 = nn.Dense( + in_channels=units, out_channels=hidden_size, has_bias=True + ) + if self._gated_proj: + self.ffn_1_gate = nn.Dense( + in_channels=units, out_channels=hidden_size, has_bias=True + ) + self.activation = get_activation(activation) + self.ffn_2 = nn.Dense( + in_channels=hidden_size, out_channels=units, has_bias=True + ) + self.layer_norm = get_norm_layer( + norm_type=normalization, in_channels=units, epsilon=layer_norm_eps + ) + self.reset_parameters() + + def reset_parameters(self): + """Initialize all sublayers with specified initialization modes.""" + apply_initialization(self.ffn_1, linear_mode=self.linear_init_mode) + if self._gated_proj: + apply_initialization(self.ffn_1_gate, linear_mode=self.linear_init_mode) + apply_initialization(self.ffn_2, linear_mode=self.ffn2_linear_init_mode) + apply_initialization(self.layer_norm, norm_mode=self.norm_init_mode) + + def construct(self, data): + """ + Forward pass of the Position-wise FFN. + + Args: + data: Input tensor of shape (batch_size, sequence_length, units) + + Returns: + Output tensor of same shape as input with transformed features + """ + residual = data + if self._pre_norm: + data = self.layer_norm(data) + if self._gated_proj: + out = self.activation(self.ffn_1_gate(data)) * self.ffn_1(data) + else: + out = self.activation(self.ffn_1(data)) + out = self.activation_dropout_layer(out) + out = self.ffn_2(out) + out = self.dropout_layer(out) + out = out + residual + if not self._pre_norm: + out = self.layer_norm(out) + return out + + +class PatchMerging3D(nn.Cell): + """3D Patch Merging Layer for spatial-temporal feature downsampling. + This layer merges patches in 3D (temporal, height, width) and applies a linear transformation + to reduce the feature dimension while increasing the channel dimension. 
+ """ + + def __init__( + self, + dim, + out_dim=None, + downsample=(1, 2, 2), + norm_layer="layer_norm", + padding_type="nearest", + linear_init_mode="0", + norm_init_mode="0", + ): + super().__init__() + self.linear_init_mode = linear_init_mode + self.norm_init_mode = norm_init_mode + self.dim = dim + if out_dim is None: + out_dim = max(downsample) * dim + self.out_dim = out_dim + self.downsample = downsample + self.padding_type = padding_type + self.reduction = nn.Dense( + downsample[0] * downsample[1] * downsample[2] * dim, out_dim, has_bias=False + ) + self.norm = get_norm_layer( + norm_layer, in_channels=downsample[0] * downsample[1] * downsample[2] * dim + ) + self.reset_parameters() + + def reset_parameters(self): + """Initialize all sublayers with specified initialization modes.""" + for cell in self.cells(): + apply_initialization( + cell, linear_mode=self.linear_init_mode, norm_mode=self.norm_init_mode + ) + + def get_out_shape(self, data_shape): + """ + Calculate the output shape given input dimensions. + + Args: + data_shape: Input shape tuple (T, H, W, C_in) + + Returns: + Tuple of output shape (T_out, H_out, W_out, C_out) + """ + t, h, w, _ = data_shape + pad_t = (self.downsample[0] - t % self.downsample[0]) % self.downsample[0] + pad_h = (self.downsample[1] - h % self.downsample[1]) % self.downsample[1] + pad_w = (self.downsample[2] - w % self.downsample[2]) % self.downsample[2] + return ( + (t + pad_t) // self.downsample[0], + (h + pad_h) // self.downsample[1], + (w + pad_w) // self.downsample[2], + self.out_dim, + ) + + def construct(self, x): + """ + Forward pass of the 3D Patch Merging layer. + + Args: + x: Input tensor of shape (B, T, H, W, C) + + Returns: + Output tensor of shape: + (B, T//downsample[0], H//downsample[1], W//downsample[2], out_dim) + """ + b, t, h, w, c = x.shape + + # padding + pad_t = (self.downsample[0] - t % self.downsample[0]) % self.downsample[0] + pad_h = (self.downsample[1] - h % self.downsample[1]) % self.downsample[1] + pad_w = (self.downsample[2] - w % self.downsample[2]) % self.downsample[2] + if pad_h or pad_h or pad_w: + t += pad_t + h += pad_h + w += pad_w + x = generalize_padding( + x, pad_t, pad_h, pad_w, padding_type=self.padding_type + ) + + x = ( + x.reshape( + ( + b, + t // self.downsample[0], + self.downsample[0], + h // self.downsample[1], + self.downsample[1], + w // self.downsample[2], + self.downsample[2], + c, + ) + ) + .permute(0, 1, 3, 5, 2, 4, 6, 7) + .reshape( + b, + t // self.downsample[0], + h // self.downsample[1], + w // self.downsample[2], + self.downsample[0] * self.downsample[1] * self.downsample[2] * c, + ) + ) + x = self.norm(x) + x = self.reduction(x) + + return x + + +class Upsample3DLayer(nn.Cell): + """3D Upsampling Layer combining interpolation and convolution. + + Performs spatial upsampling (with optional temporal upsampling) followed by convolution. + The operation consists of: + 1. Spatial upsampling using nearest-neighbor interpolation + 2. 
2D or 3D convolution to refine features and adjust channel dimensions + + Note: Currently only implements 2D upsampling (spatial only) + """ + + def __init__( + self, + dim, + out_dim, + target_size, + kernel_size=3, + conv_init_mode="0", + ): + super().__init__() + self.conv_init_mode = conv_init_mode + self.target_size = target_size + self.out_dim = out_dim + self.up = nn.Upsample(size=(target_size[1], target_size[2]), mode="nearest") + self.conv = nn.Conv2d( + in_channels=dim, + out_channels=out_dim, + kernel_size=(kernel_size, kernel_size), + padding=kernel_size // 2, + has_bias=True, + pad_mode="pad", + ) + self.reset_parameters() + + def reset_parameters(self): + """Initialize all sublayers with specified initialization modes.""" + for cell in self.cells(): + apply_initialization(cell, conv_mode=self.conv_init_mode) + + def construct(self, x): + """Forward pass of the 3D Upsampling layer.""" + b, t, h, w, c = x.shape + assert self.target_size[0] == t + x = x.reshape(b * t, h, w, c).permute(0, 3, 1, 2) + x = self.up(x) + return ( + self.conv(x) + .permute(0, 2, 3, 1) + .reshape((b,) + self.target_size + (self.out_dim,)) + ) + + +def cuboid_reorder(data, cuboid_size, strategy): + """Reorder the tensor into (B, num_cuboids, bT * bH * bW, C) + + We assume that the tensor shapes are divisible to the cuboid sizes. + + Parameters + ---------- + data + The input data + cuboid_size + The size of the cuboid + strategy + The cuboid strategy + + Returns + ------- + reordered_data + Shape will be (B, num_cuboids, bT * bH * bW, C) + num_cuboids = T / bT * H / bH * W / bW + """ + b, t, h, w, c = data.shape + num_cuboids = t // cuboid_size[0] * h // cuboid_size[1] * w // cuboid_size[2] + cuboid_volume = cuboid_size[0] * cuboid_size[1] * cuboid_size[2] + intermediate_shape = [] + + nblock_axis = [] + block_axis = [] + for i, (block_size, total_size, ele_strategy) in enumerate( + zip(cuboid_size, (t, h, w), strategy) + ): + if ele_strategy == "l": + intermediate_shape.extend([total_size // block_size, block_size]) + nblock_axis.append(2 * i + 1) + block_axis.append(2 * i + 2) + elif ele_strategy == "d": + intermediate_shape.extend([block_size, total_size // block_size]) + nblock_axis.append(2 * i + 2) + block_axis.append(2 * i + 1) + else: + raise NotImplementedError + + a = (b,) + tuple(intermediate_shape) + (c,) + data = data.reshape(a) + reordered_data = data.permute((0,) + tuple(nblock_axis) + tuple(block_axis) + (7,)) + reordered_data = reordered_data.reshape((b, num_cuboids, cuboid_volume, c)) + return reordered_data + + +def cuboid_reorder_reverse(data, cuboid_size, strategy, orig_data_shape): + """Reverse the reordered cuboid back to the original space + + Parameters + ---------- + data + cuboid_size + strategy + orig_data_shape + + Returns + ------- + data + The recovered data + """ + b, _, _, c = data.shape + t, h, w = orig_data_shape + + permutation_axis = [0] + for i, (_, _, ele_strategy) in enumerate( + zip(cuboid_size, (t, h, w), strategy) + ): + if ele_strategy == "l": + permutation_axis.append(i + 1) + permutation_axis.append(i + 4) + elif ele_strategy == "d": + permutation_axis.append(i + 4) + permutation_axis.append(i + 1) + else: + raise NotImplementedError + permutation_axis.append(7) + data = data.reshape( + b, + t // cuboid_size[0], + h // cuboid_size[1], + w // cuboid_size[2], + cuboid_size[0], + cuboid_size[1], + cuboid_size[2], + c, + ) + data = data.permute(permutation_axis) + data = data.reshape((b, t, h, w, c)) + return data + + +@lru_cache() +def 
compute_cuboid_self_attention_mask( + data_shape, cuboid_size, shift_size, strategy, padding_type +): + """compute_cuboid_self_attention_mask""" + t, h, w = data_shape + pad_t = (cuboid_size[0] - t % cuboid_size[0]) % cuboid_size[0] + pad_h = (cuboid_size[1] - h % cuboid_size[1]) % cuboid_size[1] + pad_w = (cuboid_size[2] - w % cuboid_size[2]) % cuboid_size[2] + + data_mask = None + if pad_t > 0 or pad_h > 0 or pad_w > 0: + if padding_type == "ignore": + data_mask = ops.ones((1, t, h, w, 1), dtype=ms.bool_) + data_mask = ops.pad( + data_mask, ((0, 0), (0, pad_t), (0, pad_h), (0, pad_w), (0, 0)) + ) + else: + data_mask = ops.ones((1, t + pad_t, h + pad_h, w + pad_w, 1), dtype=ms.bool_) + + if any(i > 0 for i in shift_size): + if padding_type == "ignore": + data_mask = ops.roll( + data_mask, + shifts=(-shift_size[0], -shift_size[1], -shift_size[2]), + dims=(1, 2, 3), + ) + t_padded, h_padded, w_padded = t + pad_t, h + pad_h, w + pad_w + if t_padded <= 0 or h_padded <= 0 or w_padded <= 0: + raise ValueError( + f"invalid padded dimensions: t={t_padded}, h={h_padded}, w={w_padded}" + ) + + shift_mask = ops.zeros((1, t_padded, h_padded, w_padded, 1)) + cnt = 0 + t_slices = ( + [ + slice(0, cuboid_size[0]), + slice(cuboid_size[0] - shift_size[0], t_padded - shift_size[0]), + slice(t_padded - cuboid_size[0], t_padded), + ] + if shift_size[0] > 0 + else [slice(0, t_padded)] + ) + + h_slices = ( + [ + slice(0, cuboid_size[1]), + slice(cuboid_size[1] - shift_size[1], h_padded - shift_size[1]), + slice(h_padded - cuboid_size[1], h_padded), + ] + if shift_size[1] > 0 + else [slice(0, h_padded)] + ) + + w_slices = ( + [ + slice(0, cuboid_size[2]), + slice(cuboid_size[2] - shift_size[2], w_padded - shift_size[2]), + slice(w_padded - cuboid_size[2], w_padded), + ] + if shift_size[2] > 0 + else [slice(0, w_padded)] + ) + + for t in t_slices: + for h in h_slices: + for w in w_slices: + shift_mask[:, t, h, w, :] = cnt + cnt += 1 + + shift_mask = cuboid_reorder(shift_mask, cuboid_size, strategy=strategy) + shift_mask = shift_mask.squeeze(-1).squeeze(0) # num_cuboids, cuboid_volume + attn_mask = (shift_mask.unsqueeze(1) - shift_mask.unsqueeze(2)) == 0 + + if padding_type == "ignore": + if padding_type == "ignore": + data_mask = cuboid_reorder(data_mask, cuboid_size, strategy=strategy) + data_mask = data_mask.squeeze(-1).squeeze(0) + attn_mask = data_mask.unsqueeze(1) * data_mask.unsqueeze(2) * attn_mask + + return attn_mask + + +def masked_softmax(att_score, mask, axis: int = -1): + """Computes softmax while ignoring masked elements with broadcastable masks. + + Parameters + ---------- + att_score : Tensor + mask : Tensor or None + Binary mask tensor of shape (..., length, ...) 
where: + - 1 indicates unmasked (valid) elements + - 0 indicates masked elements + Must be broadcastable with att_score + axis : int, optional + + Returns + ------- + Tensor + Softmax output of same shape as input att_score, with: + - Proper attention weights for unmasked elements + - Zero weights for masked elements + """ + if mask is not None: + # Fill in the masked scores with a very small value + if att_score.dtype == ms.float16: + att_score = att_score.masked_fill(ops.logical_not(mask), -1e4) + else: + att_score = att_score.masked_fill(ops.logical_not(mask), -1e18) + att_weights = ops.softmax(att_score, axis=axis) * mask + else: + att_weights = ops.softmax(att_score, axis=axis) + return att_weights + + +def update_cuboid_size_shift_size(data_shape, cuboid_size, shift_size, strategy): + """Update the + + Parameters + ---------- + data_shape + The shape of the data + cuboid_size + Size of the cuboid + shift_size + Size of the shift + strategy + The strategy of attention + + Returns + ------- + new_cuboid_size + Size of the cuboid + new_shift_size + Size of the shift + """ + new_cuboid_size = list(cuboid_size) + new_shift_size = list(shift_size) + for i in range(len(data_shape)): + if strategy[i] == "d": + new_shift_size[i] = 0 + if data_shape[i] <= cuboid_size[i]: + new_cuboid_size[i] = data_shape[i] + new_shift_size[i] = 0 + return tuple(new_cuboid_size), tuple(new_shift_size) + + +class CuboidSelfAttentionLayer(nn.Cell): + """ + A self-attention layer designed for 3D data (e.g., video or 3D images), + implementing cuboid-based attention with optional global vectors and relative position encoding. + """ + def __init__( + self, + dim, + num_heads, + cuboid_size=(2, 7, 7), + shift_size=(0, 0, 0), + strategy=("l", "l", "l"), + padding_type="ignore", + qkv_bias=False, + qk_scale=None, + attn_drop=0.0, + proj_drop=0.0, + use_final_proj=True, + norm_layer="layer_norm", + use_global_vector=False, + use_global_self_attn=False, + separate_global_qkv=False, + global_dim_ratio=1, + use_relative_pos=True, + attn_linear_init_mode="0", + ffn_linear_init_mode="2", + norm_init_mode="0", + ): + """Initialize the CuboidSelfAttentionLayer. + + Args: + dim (int): Input feature dimension. + num_heads (int): Number of attention heads. + cuboid_size (tuple): 3D dimensions (T, H, W) of the cuboid blocks. + shift_size (tuple): Shift sizes for each dimension to avoid attention blindness. + strategy (tuple): Strategy for each dimension ('l' for local, 'g' for global). + padding_type (str): Padding method for attention computation ("ignore", "zeros", "nearest"). + qkv_bias (bool): Whether to include bias in QKV projections. + qk_scale (float, optional): Scaling factor for QK dot product. Defaults to head_dim**-0.5. + attn_drop (float): Dropout rate after attention softmax. + proj_drop (float): Dropout rate after output projection. + use_final_proj (bool): Whether to apply the final linear projection. + norm_layer (str): Type of normalization layer ("layer_norm", etc.). + use_global_vector (bool): Whether to include a global vector in attention. + use_global_self_attn (bool): Whether to apply self-attention to global vectors. + separate_global_qkv (bool): Whether to use separate QKV for global vectors. + global_dim_ratio (int): Dimension ratio for global vector (requires separate_global_qkv=True if !=1). + use_relative_pos (bool): Whether to use relative position embeddings. + attn_linear_init_mode (str): Initialization mode for attention linear layers. 
+ ffn_linear_init_mode (str): Initialization mode for FFN linear layers. + norm_init_mode (str): Initialization mode for normalization layers. + """ + super().__init__() + # initialization + self.attn_linear_init_mode = attn_linear_init_mode + self.ffn_linear_init_mode = ffn_linear_init_mode + self.norm_init_mode = norm_init_mode + + assert dim % num_heads == 0 + self.num_heads = num_heads + self.dim = dim + self.cuboid_size = cuboid_size + self.shift_size = shift_size + self.strategy = strategy + self.padding_type = padding_type + self.use_final_proj = use_final_proj + self.use_relative_pos = use_relative_pos + # global vectors + self.use_global_vector = use_global_vector + self.use_global_self_attn = use_global_self_attn + self.separate_global_qkv = separate_global_qkv + self.global_dim_ratio = global_dim_ratio + assert self.padding_type in ["ignore", "zeros", "nearest"] + head_dim = dim // num_heads + self.scale = qk_scale or head_dim**-0.5 + + if self.use_relative_pos: + self.relative_position_bias_table = Parameter( + initializer( + TruncatedNormal(sigma=0.02), + [ + (2 * cuboid_size[0] - 1) + * (2 * cuboid_size[1] - 1) + * (2 * cuboid_size[2] - 1), + num_heads, + ], + ms.float32, + ) + ) + self.relative_position_bias_table.name = "relative_position_bias_table" + coords_t = ops.arange(self.cuboid_size[0]) + coords_h = ops.arange(self.cuboid_size[1]) + coords_w = ops.arange(self.cuboid_size[2]) + coords = ops.stack(ops.meshgrid(coords_t, coords_h, coords_w)) + + coords_flatten = ops.flatten(coords, start_dim=1) + relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] + relative_coords = relative_coords.permute(1, 2, 0) + relative_coords[:, :, 0] += self.cuboid_size[0] - 1 + relative_coords[:, :, 1] += self.cuboid_size[1] - 1 + relative_coords[:, :, 2] += self.cuboid_size[2] - 1 + + relative_coords[:, :, 0] *= (2 * self.cuboid_size[1] - 1) * ( + 2 * self.cuboid_size[2] - 1 + ) + relative_coords[:, :, 1] *= 2 * self.cuboid_size[2] - 1 + relative_position_index = relative_coords.sum(-1) + self.relative_position_index = Parameter( + relative_position_index, + name="relative_position_index", + requires_grad=False, + ) + self.qkv = nn.Dense(dim, dim * 3, has_bias=qkv_bias) + self.attn_drop = nn.Dropout(p=attn_drop) + + if use_final_proj: + self.proj = nn.Dense(dim, dim) + self.proj_drop = nn.Dropout(p=proj_drop) + + if self.use_global_vector: + self.global_proj = nn.Dense( + in_channels=global_dim_ratio * dim, + out_channels=global_dim_ratio * dim, + ) + + self.norm = get_norm_layer(norm_layer, in_channels=dim) + if self.use_global_vector: + self.global_vec_norm = get_norm_layer( + norm_layer, in_channels=global_dim_ratio * dim + ) + + self.reset_parameters() + + def reset_parameters(self): + '''set_parameters''' + apply_initialization(self.qkv, linear_mode=self.attn_linear_init_mode) + if self.use_final_proj: + apply_initialization(self.proj, linear_mode=self.ffn_linear_init_mode) + apply_initialization(self.norm, norm_mode=self.norm_init_mode) + if self.use_global_vector: + if self.separate_global_qkv: + apply_initialization( + self.l2g_q_net, linear_mode=self.attn_linear_init_mode + ) + apply_initialization( + self.l2g_global_kv_net, linear_mode=self.attn_linear_init_mode + ) + apply_initialization( + self.g2l_global_q_net, linear_mode=self.attn_linear_init_mode + ) + apply_initialization( + self.g2l_k_net, linear_mode=self.attn_linear_init_mode + ) + apply_initialization( + self.g2l_v_net, linear_mode=self.attn_linear_init_mode + ) + if self.use_global_self_attn: + 
apply_initialization( + self.g2g_global_qkv_net, linear_mode=self.attn_linear_init_mode + ) + else: + apply_initialization( + self.global_qkv, linear_mode=self.attn_linear_init_mode + ) + apply_initialization(self.global_vec_norm, norm_mode=self.norm_init_mode) + + def construct(self, x): + """ + Constructs the output by applying normalization, padding, shifting, and attention mechanisms. + + Parameters: + - x (Tensor): Input tensor with shape (batch, time, height, width, channels). + - global_vectors (Tensor, optional): Global vectors used in global-local interactions. Defaults to None. + + Returns: + - Tensor: Processed tensor after applying all transformations. + - Tensor: Updated global vectors if global vectors are used; otherwise, returns only the processed tensor. + """ + x = self.norm(x) + batch, time, height, width, channels = x.shape + assert channels == self.dim + cuboid_size, shift_size = update_cuboid_size_shift_size( + (time, height, width), self.cuboid_size, self.shift_size, self.strategy + ) + pad_t = (cuboid_size[0] - time % cuboid_size[0]) % cuboid_size[0] + pad_h = (cuboid_size[1] - height % cuboid_size[1]) % cuboid_size[1] + pad_w = (cuboid_size[2] - width % cuboid_size[2]) % cuboid_size[2] + x = generalize_padding(x, pad_t, pad_h, pad_w, self.padding_type) + if any(i > 0 for i in shift_size): + shifted_x = ops.roll( + x, + shifts=(-shift_size[0], -shift_size[1], -shift_size[2]), + dims=(1, 2, 3), + ) + else: + shifted_x = x + reordered_x = cuboid_reorder( + shifted_x, cuboid_size=cuboid_size, strategy=self.strategy + ) + _, num_cuboids, cuboid_volume, _ = reordered_x.shape + attn_mask = compute_cuboid_self_attention_mask( + (time, height, width), + cuboid_size, + shift_size=shift_size, + strategy=self.strategy, + padding_type=self.padding_type, + ) + head_c = channels // self.num_heads + qkv = ( + self.qkv(reordered_x) + .reshape(batch, num_cuboids, cuboid_volume, 3, self.num_heads, head_c) + .permute(3, 0, 4, 1, 2, 5) + ) + q, k, v = ( + qkv[0], + qkv[1], + qkv[2], + ) + q = q * self.scale + attn_score = q @ k.swapaxes(-2, -1) + if self.use_relative_pos: + relative_position_bias = self.relative_position_bias_table[ + self.relative_position_index[:cuboid_volume, :cuboid_volume].reshape(-1) + ].reshape(cuboid_volume, cuboid_volume, -1) + relative_position_bias = relative_position_bias.permute(2, 0, 1).unsqueeze( + 1 + ) + attn_score = attn_score + relative_position_bias + attn_score = masked_softmax(attn_score, mask=attn_mask) + attn_score = self.attn_drop(attn_score) + reordered_x = ( + (attn_score @ v) + .permute(0, 2, 3, 1, 4) + .reshape(batch, num_cuboids, cuboid_volume, self.dim) + ) + + if self.use_final_proj: + reordered_x = self.proj_drop(self.proj(reordered_x)) + if self.use_global_vector: + new_global_vector = self.proj_drop(self.global_proj(new_global_vector)) + shifted_x = cuboid_reorder_reverse( + reordered_x, + cuboid_size=cuboid_size, + strategy=self.strategy, + orig_data_shape=(time + pad_t, height + pad_h, width + pad_w), + ) + if any(i > 0 for i in shift_size): + x = ops.roll( + shifted_x, + shifts=(shift_size[0], shift_size[1], shift_size[2]), + dims=(1, 2, 3), + ) + else: + x = shifted_x + x = generalize_unpadding( + x, pad_t=pad_t, pad_h=pad_h, pad_w=pad_w, padding_type=self.padding_type + ) + if self.use_global_vector: + return x, new_global_vector + return x + + +class StackCuboidSelfAttentionBlock(nn.Cell): + """ + + - "use_inter_ffn" is True + x --> attn1 --> ffn1 --> attn2 --> ... 
--> ffn_k --> out + - "use_inter_ffn" is False + x --> attn1 --> attn2 --> ... attnk --> ffnk --> out + If we have enabled global memory vectors, each attention will be a + + """ + + def __init__( + self, + dim=None, + num_heads=None, + block_cuboid_size=None, + block_shift_size=None, + block_strategy=None, + padding_type="ignore", + qkv_bias=False, + qk_scale=None, + attn_drop=0.0, + proj_drop=0.0, + ffn_drop=0.0, + activation="leaky", + gated_ffn=False, + norm_layer="layer_norm", + use_inter_ffn=False, + use_global_vector=False, + use_global_vector_ffn=True, + use_global_self_attn=False, + separate_global_qkv=False, + global_dim_ratio=1, + use_relative_pos=True, + use_final_proj=True, + # initialization + attn_linear_init_mode="0", + ffn_linear_init_mode="0", + ffn2_linear_init_mode="2", + attn_proj_linear_init_mode="2", + norm_init_mode="0", + ): + super().__init__() + # initialization + self.attn_linear_init_mode = attn_linear_init_mode + self.ffn_linear_init_mode = ffn_linear_init_mode + self.attn_proj_linear_init_mode = attn_proj_linear_init_mode + self.norm_init_mode = norm_init_mode + self.num_attn = len(block_cuboid_size) + self.use_inter_ffn = use_inter_ffn + # global vectors + self.use_global_vector = use_global_vector + self.use_global_vector_ffn = use_global_vector_ffn + self.use_global_self_attn = use_global_self_attn + self.global_dim_ratio = global_dim_ratio + + if self.use_inter_ffn: + self.ffn_l = nn.CellList( + [ + PositionwiseFFN( + units=dim, + hidden_size=4 * dim, + activation_dropout=ffn_drop, + dropout=ffn_drop, + gated_proj=gated_ffn, + activation=activation, + normalization=norm_layer, + pre_norm=True, + linear_init_mode=ffn_linear_init_mode, + ffn2_linear_init_mode=ffn2_linear_init_mode, + norm_init_mode=norm_init_mode, + ) + for _ in range(self.num_attn) + ] + ) + if self.use_global_vector_ffn and self.use_global_vector: + self.global_ffn_l = nn.CellList( + [ + PositionwiseFFN( + units=global_dim_ratio * dim, + hidden_size=global_dim_ratio * 4 * dim, + activation_dropout=ffn_drop, + dropout=ffn_drop, + gated_proj=gated_ffn, + activation=activation, + normalization=norm_layer, + pre_norm=True, + linear_init_mode=ffn_linear_init_mode, + ffn2_linear_init_mode=ffn2_linear_init_mode, + norm_init_mode=norm_init_mode, + ) + for _ in range(self.num_attn) + ] + ) + else: + self.ffn_l = nn.CellList( + [ + PositionwiseFFN( + units=dim, + hidden_size=4 * dim, + activation_dropout=ffn_drop, + dropout=ffn_drop, + gated_proj=gated_ffn, + activation=activation, + normalization=norm_layer, + pre_norm=True, + linear_init_mode=ffn_linear_init_mode, + ffn2_linear_init_mode=ffn2_linear_init_mode, + norm_init_mode=norm_init_mode, + ) + ] + ) + if self.use_global_vector_ffn and self.use_global_vector: + self.global_ffn_l = nn.CellList( + [ + PositionwiseFFN( + units=global_dim_ratio * dim, + hidden_size=global_dim_ratio * 4 * dim, + activation_dropout=ffn_drop, + dropout=ffn_drop, + gated_proj=gated_ffn, + activation=activation, + normalization=norm_layer, + pre_norm=True, + linear_init_mode=ffn_linear_init_mode, + ffn2_linear_init_mode=ffn2_linear_init_mode, + norm_init_mode=norm_init_mode, + ) + ] + ) + self.attn_l = nn.CellList( + [ + CuboidSelfAttentionLayer( + dim=dim, + num_heads=num_heads, + cuboid_size=ele_cuboid_size, + shift_size=ele_shift_size, + strategy=ele_strategy, + padding_type=padding_type, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + attn_drop=attn_drop, + proj_drop=proj_drop, + norm_layer=norm_layer, + use_global_vector=use_global_vector, + 
use_global_self_attn=use_global_self_attn, + separate_global_qkv=separate_global_qkv, + global_dim_ratio=global_dim_ratio, + use_relative_pos=use_relative_pos, + use_final_proj=use_final_proj, + attn_linear_init_mode=attn_linear_init_mode, + ffn_linear_init_mode=attn_proj_linear_init_mode, + norm_init_mode=norm_init_mode, + ) + for ele_cuboid_size, ele_shift_size, ele_strategy in zip( + block_cuboid_size, block_shift_size, block_strategy + ) + ] + ) + + def reset_parameters(self): + for m in self.ffn_l: + m.reset_parameters() + if self.use_global_vector_ffn and self.use_global_vector: + for m in self.global_ffn_l: + m.reset_parameters() + for m in self.attn_l: + m.reset_parameters() + + def construct(self, x, global_vectors=None): + """ + Constructs the network output by processing input data with attention and feed-forward layers. + + Args: + x (Tensor): Input data tensor. + global_vectors (Tensor, optional): Global vectors for contextual processing. Defaults to None. + + Returns: + Union[Tensor, Tuple[Tensor, Tensor]]: + - If `global_vectors` is used, returns a tuple (processed_x, updated_global_vectors). + - Otherwise, returns the processed input tensor x. + """ + if self.use_inter_ffn: + if self.use_global_vector: + for idx, (attn, ffn) in enumerate(zip(self.attn_l, self.ffn_l)): + x_out, global_vectors_out = attn(x, global_vectors) + x = x + x_out + global_vectors = global_vectors + global_vectors_out + x = ffn(x) + if self.use_global_vector_ffn: + global_vectors = self.global_ffn_l[idx](global_vectors) + return x, global_vectors + for idx, (attn, ffn) in enumerate(zip(self.attn_l, self.ffn_l)): + x_ = attn(x) + x = x + x_ + x = ffn(x) + return x + if self.use_global_vector: + for idx, attn in enumerate(self.attn_l): + x_out, global_vectors_out = attn(x, global_vectors) + x = x + x_out + global_vectors = global_vectors + global_vectors_out + x = self.ffn_l[0](x) + if self.use_global_vector_ffn: + global_vectors = self.global_ffn_l[0](global_vectors) + return x, global_vectors + for idx, attn in enumerate(self.attn_l): + out = attn(x) + x = x + out + x = self.ffn_l[0](x) + return x diff --git a/MindEarth/applications/nowcasting/PreDiff/src/diffusion/cuboid_transformer_unet.py b/MindEarth/applications/nowcasting/PreDiff/src/diffusion/cuboid_transformer_unet.py new file mode 100644 index 0000000000000000000000000000000000000000..ae3292ceb0fffed8c4e78e663b53addf3cb44b0b --- /dev/null +++ b/MindEarth/applications/nowcasting/PreDiff/src/diffusion/cuboid_transformer_unet.py @@ -0,0 +1,527 @@ +# Copyright 2023 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"CuboidTransformerUNet base class" +from mindspore import ops, nn, Parameter +import mindspore.common.initializer as initializer + +from src.utils import timestep_embedding, apply_initialization, round_to, self_axial +from .time_embed import TimeEmbedLayer, TimeEmbedResBlock +from .cuboid_transformer import ( + PosEmbed, + Upsample3DLayer, + PatchMerging3D, + StackCuboidSelfAttentionBlock, +) + + +class CuboidTransformerUNet(nn.Cell): + r""" + U-Net style CuboidTransformer that parametrizes `p(x_{t-1}|x_t)`. + It takes `x_t`, `t` as input. + The conditioning can be concatenated to the input like the U-Net in FVD paper. + + For each block, we apply the StackCuboidSelfAttention in U-Net style + + x --> attn --> downscale --> ... --> z --> attn --> upscale --> ... --> out + + Besides, we insert the embeddings of the timesteps `t` before each cuboid attention blocks. + """ + + def __init__( + self, + input_shape=None, + target_shape=None, + base_units=256, + block_units=None, + scale_alpha=1.0, + depth=None, + downsample=2, + downsample_type="patch_merge", + upsample_type="upsample", + upsample_kernel_size=3, + use_attn_pattern=True, + block_cuboid_size=None, + block_cuboid_strategy=None, + block_cuboid_shift_size=None, + num_heads=4, + attn_drop=0.0, + proj_drop=0.0, + ffn_drop=0.0, + ffn_activation="leaky", + gated_ffn=False, + norm_layer="layer_norm", + use_inter_ffn=True, + hierarchical_pos_embed=False, + padding_type="ignore", + use_relative_pos=True, + self_attn_use_final_proj=True, + # global vectors + num_global_vectors=False, + use_global_vector_ffn=True, + use_global_self_attn=False, + separate_global_qkv=False, + global_dim_ratio=1, + # initialization + attn_linear_init_mode="0", + ffn_linear_init_mode="0", + ffn2_linear_init_mode="2", + attn_proj_linear_init_mode="2", + conv_init_mode="0", + down_linear_init_mode="0", + global_proj_linear_init_mode="2", + norm_init_mode="0", + # timestep embedding for diffusion + time_embed_channels_mult=4, + time_embed_use_scale_shift_norm=False, + time_embed_dropout=0.0, + unet_res_connect=True, + ): + super().__init__() + # initialization mode + self.attn_linear_init_mode = attn_linear_init_mode + self.ffn_linear_init_mode = ffn_linear_init_mode + self.ffn2_linear_init_mode = ffn2_linear_init_mode + self.attn_proj_linear_init_mode = attn_proj_linear_init_mode + self.conv_init_mode = conv_init_mode + self.down_linear_init_mode = down_linear_init_mode + self.global_proj_linear_init_mode = global_proj_linear_init_mode + self.norm_init_mode = norm_init_mode + + self.input_shape = input_shape + self.target_shape = target_shape + self.num_blocks = len(depth) + self.depth = depth + self.base_units = base_units + self.scale_alpha = scale_alpha + self.downsample = downsample + self.downsample_type = downsample_type + self.upsample_type = upsample_type + self.upsample_kernel_size = upsample_kernel_size + if not isinstance(downsample, (tuple, list)): + downsample = (1, downsample, downsample) + if block_units is None: + block_units = [ + round_to(base_units * int((max(downsample) ** scale_alpha) ** i), 4) + for i in range(self.num_blocks) + ] + else: + assert len(block_units) == self.num_blocks and block_units[0] == base_units + self.block_units = block_units + self.hierarchical_pos_embed = hierarchical_pos_embed + self.num_global_vectors = num_global_vectors + use_global_vector = num_global_vectors > 0 + self.use_global_vector = use_global_vector + if global_dim_ratio != 1: 
+ assert ( + separate_global_qkv is True + ), f"Setting global_dim_ratio != 1 requires separate_global_qkv == True." + self.global_dim_ratio = global_dim_ratio + self.use_global_vector_ffn = use_global_vector_ffn + + self.time_embed_channels_mult = time_embed_channels_mult + self.time_embed_channels = self.block_units[0] * time_embed_channels_mult + self.time_embed_use_scale_shift_norm = time_embed_use_scale_shift_norm + self.time_embed_dropout = time_embed_dropout + self.unet_res_connect = unet_res_connect + + if self.use_global_vector: + self.init_global_vectors = Parameter( + ops.zeros((self.num_global_vectors, global_dim_ratio * base_units)) + ) + + t_in, h_in, w_in, c_in = input_shape + t_out, h_out, w_out, c_out = target_shape + assert h_in == h_out and w_in == w_out and c_in == c_out + self.t_in = t_in + self.t_out = t_out + self.first_proj = TimeEmbedResBlock( + channels=self.data_shape[-1], + emb_channels=None, + dropout=proj_drop, + out_channels=self.base_units, + use_conv=False, + use_embed=False, + use_scale_shift_norm=False, + dims=3, + up=False, + down=False, + ) + self.pos_embed = PosEmbed( + embed_dim=base_units, + max_t=self.data_shape[0], + max_h=h_in, + max_w=w_in, + ) + + # diffusion time embed + self.time_embed = TimeEmbedLayer( + base_channels=self.block_units[0], + time_embed_channels=self.time_embed_channels, + ) + # # inner U-Net + if self.num_blocks > 1: + # Construct downsampling layers + if downsample_type == "patch_merge": + self.downsample_layers = nn.CellList( + [ + PatchMerging3D( + dim=self.block_units[i], + downsample=downsample, + padding_type=padding_type, + out_dim=self.block_units[i + 1], + linear_init_mode=down_linear_init_mode, + norm_init_mode=norm_init_mode, + ) + for i in range(self.num_blocks - 1) + ] + ) + else: + raise NotImplementedError + if self.use_global_vector: + self.down_layer_global_proj = nn.CellList( + [ + nn.Dense( + in_channels=global_dim_ratio * self.block_units[i], + out_channels=global_dim_ratio * self.block_units[i + 1], + ) + for i in range(self.num_blocks - 1) + ] + ) + # Construct upsampling layers + if self.upsample_type == "upsample": + self.upsample_layers = nn.CellList( + [ + Upsample3DLayer( + dim=self.mem_shapes[i + 1][-1], + out_dim=self.mem_shapes[i][-1], + target_size=self.mem_shapes[i][:3], + kernel_size=upsample_kernel_size, + conv_init_mode=conv_init_mode, + ) + for i in range(self.num_blocks - 1) + ] + ) + else: + raise NotImplementedError + if self.use_global_vector: + self.up_layer_global_proj = nn.CellList( + [ + nn.Dense( + in_channels=global_dim_ratio * self.block_units[i + 1], + out_channels=global_dim_ratio * self.block_units[i], + ) + for i in range(self.num_blocks - 1) + ] + ) + if self.hierarchical_pos_embed: + self.down_hierarchical_pos_embed_l = nn.CellList( + [ + PosEmbed( + embed_dim=self.block_units[i], + max_t=self.mem_shapes[i][0], + max_h=self.mem_shapes[i][1], + max_w=self.mem_shapes[i][2], + ) + for i in range(self.num_blocks - 1) + ] + ) + self.up_hierarchical_pos_embed_l = nn.CellList( + [ + PosEmbed( + embed_dim=self.block_units[i], + max_t=self.mem_shapes[i][0], + max_h=self.mem_shapes[i][1], + max_w=self.mem_shapes[i][2], + ) + for i in range(self.num_blocks - 1) + ] + ) + + if use_attn_pattern: + block_attn_patterns = self.depth + block_cuboid_size = [] + block_cuboid_strategy = [] + block_cuboid_shift_size = [] + for idx, _ in enumerate(block_attn_patterns): + cuboid_size, strategy, shift_size = self_axial(self.mem_shapes[idx]) + block_cuboid_size.append(cuboid_size) + 
block_cuboid_strategy.append(strategy) + block_cuboid_shift_size.append(shift_size) + else: + if not isinstance(block_cuboid_size[0][0], (list, tuple)): + block_cuboid_size = [block_cuboid_size for _ in range(self.num_blocks)] + else: + assert ( + len(block_cuboid_size) == self.num_blocks + ), f"Incorrect input format! Received block_cuboid_size={block_cuboid_size}" + + if not isinstance(block_cuboid_strategy[0][0], (list, tuple)): + block_cuboid_strategy = [ + block_cuboid_strategy for _ in range(self.num_blocks) + ] + else: + assert ( + len(block_cuboid_strategy) == self.num_blocks + ), f"Incorrect input format! Received block_strategy={block_cuboid_strategy}" + + if not isinstance(block_cuboid_shift_size[0][0], (list, tuple)): + block_cuboid_shift_size = [ + block_cuboid_shift_size for _ in range(self.num_blocks) + ] + else: + assert ( + len(block_cuboid_shift_size) == self.num_blocks + ), f"Incorrect input format! Received block_shift_size={block_cuboid_shift_size}" + self.block_cuboid_size = block_cuboid_size + self.block_cuboid_strategy = block_cuboid_strategy + self.block_cuboid_shift_size = block_cuboid_shift_size + + # cuboid self attention blocks + down_self_blocks = [] + up_self_blocks = [] + # ResBlocks that incorporate `time_embed` + down_time_embed_blocks = [] + up_time_embed_blocks = [] + for i in range(self.num_blocks): + down_time_embed_blocks.append( + TimeEmbedResBlock( + channels=self.mem_shapes[i][-1], + emb_channels=self.time_embed_channels, + dropout=self.time_embed_dropout, + out_channels=self.mem_shapes[i][-1], + use_conv=False, + use_embed=True, + use_scale_shift_norm=self.time_embed_use_scale_shift_norm, + dims=3, + up=False, + down=False, + ) + ) + + ele_depth = depth[i] + stack_cuboid_blocks = [ + StackCuboidSelfAttentionBlock( + dim=self.mem_shapes[i][-1], + num_heads=num_heads, + block_cuboid_size=block_cuboid_size[i], + block_strategy=block_cuboid_strategy[i], + block_shift_size=block_cuboid_shift_size[i], + attn_drop=attn_drop, + proj_drop=proj_drop, + ffn_drop=ffn_drop, + activation=ffn_activation, + gated_ffn=gated_ffn, + norm_layer=norm_layer, + use_inter_ffn=use_inter_ffn, + padding_type=padding_type, + use_global_vector=use_global_vector, + use_global_vector_ffn=use_global_vector_ffn, + use_global_self_attn=use_global_self_attn, + separate_global_qkv=separate_global_qkv, + global_dim_ratio=global_dim_ratio, + use_relative_pos=use_relative_pos, + use_final_proj=self_attn_use_final_proj, + # initialization + attn_linear_init_mode=attn_linear_init_mode, + ffn_linear_init_mode=ffn_linear_init_mode, + ffn2_linear_init_mode=ffn2_linear_init_mode, + attn_proj_linear_init_mode=attn_proj_linear_init_mode, + norm_init_mode=norm_init_mode, + ) + for _ in range(ele_depth) + ] + down_self_blocks.append(nn.CellList(stack_cuboid_blocks)) + + up_time_embed_blocks.append( + TimeEmbedResBlock( + channels=self.mem_shapes[i][-1], + emb_channels=self.time_embed_channels, + dropout=self.time_embed_dropout, + out_channels=self.mem_shapes[i][-1], + use_conv=False, + use_embed=True, + use_scale_shift_norm=self.time_embed_use_scale_shift_norm, + dims=3, + up=False, + down=False, + ) + ) + + stack_cuboid_blocks = [ + StackCuboidSelfAttentionBlock( + dim=self.mem_shapes[i][-1], + num_heads=num_heads, + block_cuboid_size=block_cuboid_size[i], + block_strategy=block_cuboid_strategy[i], + block_shift_size=block_cuboid_shift_size[i], + attn_drop=attn_drop, + proj_drop=proj_drop, + ffn_drop=ffn_drop, + activation=ffn_activation, + gated_ffn=gated_ffn, + norm_layer=norm_layer, + 
use_inter_ffn=use_inter_ffn, + padding_type=padding_type, + use_global_vector=use_global_vector, + use_global_vector_ffn=use_global_vector_ffn, + use_global_self_attn=use_global_self_attn, + separate_global_qkv=separate_global_qkv, + global_dim_ratio=global_dim_ratio, + use_relative_pos=use_relative_pos, + use_final_proj=self_attn_use_final_proj, + # initialization + attn_linear_init_mode=attn_linear_init_mode, + ffn_linear_init_mode=ffn_linear_init_mode, + ffn2_linear_init_mode=ffn2_linear_init_mode, + attn_proj_linear_init_mode=attn_proj_linear_init_mode, + norm_init_mode=norm_init_mode, + ) + for _ in range(ele_depth) + ] + up_self_blocks.append(nn.CellList(stack_cuboid_blocks)) + self.down_self_blocks = nn.CellList(down_self_blocks) + self.up_self_blocks = nn.CellList(up_self_blocks) + self.down_time_embed_blocks = nn.CellList(down_time_embed_blocks) + self.up_time_embed_blocks = nn.CellList(up_time_embed_blocks) + self.final_proj = nn.Dense(self.base_units, c_out) + + self.reset_parameters() + + def reset_parameters(self): + '''init parameters''' + if self.num_global_vectors > 0: + initializer.TruncatedNormal(self.init_global_vectors, sigma=0.02) + self.first_proj.reset_parameters() + apply_initialization(self.final_proj, linear_mode="2") + self.pos_embed.reset_parameters() + for block in self.down_self_blocks: + for m in block: + m.reset_parameters() + for m in self.down_time_embed_blocks: + m.reset_parameters() + for block in self.up_self_blocks: + for m in block: + m.reset_parameters() + for m in self.up_time_embed_blocks: + m.reset_parameters() + if self.num_blocks > 1: + for m in self.downsample_layers: + m.reset_parameters() + for m in self.upsample_layers: + m.reset_parameters() + if self.use_global_vector: + apply_initialization( + self.down_layer_global_proj, + linear_mode=self.global_proj_linear_init_mode, + ) + apply_initialization( + self.up_layer_global_proj, + linear_mode=self.global_proj_linear_init_mode, + ) + if self.hierarchical_pos_embed: + for m in self.down_hierarchical_pos_embed_l: + m.reset_parameters() + for m in self.up_hierarchical_pos_embed_l: + m.reset_parameters() + + @property + def data_shape(self): + '''set datashape''' + if not hasattr(self, "_data_shape"): + t_in, h_in, w_in, c_in = self.input_shape + t_out, h_out, w_out, c_out = self.target_shape + assert h_in == h_out and w_in == w_out and c_in == c_out + self._data_shape = ( + t_in + t_out, + h_in, + w_in, + c_in + 1, + ) + return self._data_shape + + @property + def mem_shapes(self): + """Get the shape of the output memory based on the input shape. This can be used for constructing the decoder. + + Returns + ------- + mem_shapes + A list of shapes of the output memory + """ + inner_data_shape = tuple(self.data_shape)[:3] + (self.base_units,) + if self.num_blocks == 1: + return [inner_data_shape] + mem_shapes = [inner_data_shape] + curr_shape = inner_data_shape + for down_layer in self.downsample_layers: + curr_shape = down_layer.get_out_shape(curr_shape) + mem_shapes.append(curr_shape) + return mem_shapes + + def construct(self, x, t, cond): + """ + + Parameters + ---------- + x: mindspore.Tensor + Shape (B, t_out, H, W, C) + t: mindspore.Tensor + Shape (B, ) + cond: mindspore.Tensor + Shape (B, t_in, H, W, C) + verbose: bool + + Returns + ------- + out: mindspore.Tensor + Shape (B, T, H, W, C) + """ + + x = ops.cat([cond, x], axis=1) + obs_indicator = ops.ones_like(x[..., :1]) + obs_indicator[:, self.t_in :, ...] 
= 0.0 + x = ops.cat([x, obs_indicator], axis=-1) + x = x.transpose((0, 4, 1, 2, 3)) + x = self.first_proj(x) + x = x.transpose((0, 2, 3, 4, 1)) + x = self.pos_embed(x) + # inner U-Net + t_emb = self.time_embed(timestep_embedding(t, self.block_units[0])) + if self.unet_res_connect: + res_connect_l = [] + for i in range(self.num_blocks): + # Downample + if i > 0: + x = self.downsample_layers[i - 1](x) + for idx in range(self.depth[i]): + x = x.transpose((0, 4, 1, 2, 3)) + x = self.down_time_embed_blocks[i](x, t_emb) + x = x.transpose((0, 2, 3, 4, 1)) + x = self.down_self_blocks[i][idx](x) + if self.unet_res_connect and i < self.num_blocks - 1: + res_connect_l.append(x) + + for i in range(self.num_blocks - 1, -1, -1): + if self.unet_res_connect and i < self.num_blocks - 1: + x = x + res_connect_l[i] + for idx in range(self.depth[i]): + x = x.transpose((0, 4, 1, 2, 3)) + x = self.up_time_embed_blocks[i](x, t_emb) + x = x.transpose((0, 2, 3, 4, 1)) + x = self.up_self_blocks[i][idx](x) + if i > 0: + x = self.upsample_layers[i - 1](x) + x = self.final_proj(x[:, self.t_in :, ...]) + return x diff --git a/MindEarth/applications/nowcasting/PreDiff/src/diffusion/forecast.py b/MindEarth/applications/nowcasting/PreDiff/src/diffusion/forecast.py new file mode 100644 index 0000000000000000000000000000000000000000..1873274edf42418f53a6f234c22b04afed7eb53d --- /dev/null +++ b/MindEarth/applications/nowcasting/PreDiff/src/diffusion/forecast.py @@ -0,0 +1,285 @@ +# Copyright 2023 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +'''diffusion inferrence''' +import time +import os +import json +from typing import Sequence, Union +import numpy as np +from einops import rearrange + +import mindspore as ms +from mindspore import ops, mint, nn + +from src.visual import vis_sevir_seq +from src.sevir_dataset import SEVIRDataset + + +def get_alignment_kwargs_avg_x(target_seq): + """Generate alignment parameters for guided sampling""" + batch_size = target_seq.shape[0] + avg_intensity = mint.mean(target_seq.view(batch_size, -1), dim=1, keepdim=True) + return {"avg_x_gt": avg_intensity * 2.0} + + +class DiffusionInferrence(nn.Cell): + """ + Class managing model inference and evaluation processes. Handles loading checkpoints, + generating predictions, calculating evaluation metrics, and saving visualization results. + """ + def __init__(self, main_module, dm, logger, config): + """ + Initialize inference manager with model, data module, logger, and configuration. 
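+
+        The configuration keys consumed here are sketched below; the key names and
+        fallback values are the in-code defaults used by this class, and the dict
+        is only an illustration, not a complete configuration file:
+
+            config = {
+                "summary": {"ckpt_path": "./ckpt/diffusion.ckpt",
+                            "summary_dir": "./summary"},
+                "eval": {"num_samples_per_context": 1,
+                         "eval_example_only": True},
+                "model": {"align": {"alignment_type": "avg_x"}},
+                "logging": {"logging_prefix": "PreDiff"},
+            }
+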
+ Args: + main_module: Main diffusion model for inference + dm: Data module providing test dataset + logger: Logging utility for evaluation progress + config: Configuration dictionary containing evaluation parameters + """ + super().__init__() + self.ckpt_path = config["summary"].get( + "ckpt_path", + "./ckpt/diffusion.ckpt", + ) # 删除 + self.num_samples = config["eval"].get("num_samples_per_context", 1) + self.eval_example_only = config["eval"].get("eval_example_only", True) + self.alignment_type = ( + config.get("model", {}).get("align", {}).get("alignment_type", "avg_x") + ) + self.use_alignment = self.alignment_type is not None + self.eval_aligned = config["eval"].get("eval_aligned", True) + self.eval_unaligned = config["eval"].get("eval_unaligned", True) + self.num_samples_per_context = config["eval"].get("num_samples_per_context", 1) + self.logging_prefix = config["logging"].get("logging_prefix", "PreDiff") + self.test_example_data_idx_list = config["eval"].get( + "test_example_data_idx_list", [0, 16, 32, 48, 64, 72, 96, 108, 128] + ) + self.main_module = main_module + self.testdataset = dm.sevir_test + self.logger = logger + self.datasetprocessing = SEVIRDataset( + data_types=["vil"], + layout="NHWT", + rescale_method=config.get("rescale_method", "01"), + ) + self.example_save_dir = config["summary"].get("summary_dir", "./summary") + + self.fs = config["eval"].get("fs", 20) + self.label_offset = config["eval"].get("label_offset", [-0.5, 0.5]) + self.label_avg_int = config["eval"].get("label_avg_int", False) + + self.current_epoch = 0 + + self.learn_logvar = ( + config.get("model", {}).get("diffusion", {}).get("learn_logvar", False) + ) + self.logvar = main_module.logvar + self.maeloss = nn.MAELoss() + self.test_metrics = { + "step": 0, + "mse": 0.0, + "mae": 0.0, + "ssim": 0.0, + "mse_kc": 0.0, + "mae_kc": 0.0, + } + + def test(self): + """Execute complete evaluation pipeline.""" + self.logger.info("============== Start Test ==============") + self.start_time = time.time() + for batch_idx, item in enumerate(self.testdataset.create_dict_iterator()): + self.test_metrics = self._test_onestep(item, batch_idx, self.test_metrics) + + self._finalize_test(self.test_metrics) + + def _test_onestep(self, item, batch_idx, metrics): + """Process one test batch and update evaluation metrics.""" + data_idx = int(batch_idx * 2) + if not self._should_test_onestep(data_idx): + return metrics + data = item.get("vil") + data = self.datasetprocessing.process_data(data) + target_seq, cond, context_seq = self._get_model_inputs(data) + aligned_preds, unaligned_preds = self._generate_predictions( + cond, target_seq + ) + metrics = self._update_metrics( + aligned_preds, unaligned_preds, target_seq, metrics + ) + self._plt_pred( + data_idx, + context_seq, + target_seq, + aligned_preds, + unaligned_preds, + metrics["step"], + ) + + metrics["step"] += 1 + return metrics + + def _should_test_onestep(self, data_idx): + """Determine if evaluation should be performed on current data index.""" + return (not self.eval_example_only) or ( + data_idx in self.test_example_data_idx_list + ) + + def _get_model_inputs(self, data): + """Extract and prepare model inputs from raw data.""" + target_seq, cond, context_seq = self.main_module.get_input( + data, return_verbose=True + ) + return target_seq, cond, context_seq + + def _generate_predictions(self, cond, target_seq): + """Generate both aligned and unaligned predictions from the model.""" + aligned_preds = [] + unaligned_preds = [] + + for _ in 
range(self.num_samples_per_context): + if self.use_alignment and self.eval_aligned: + aligned_pred = self._sample_with_alignment( + cond, target_seq + ) + aligned_preds.append(aligned_pred) + + if self.eval_unaligned: + unaligned_pred = self._sample_without_alignment(cond) + unaligned_preds.append(unaligned_pred) + + return aligned_preds, unaligned_preds + + def _sample_with_alignment(self, cond, target_seq): + """Generate predictions using alignment mechanism.""" + alignment_kwargs = get_alignment_kwargs_avg_x(target_seq) + pred_seq = self.main_module.sample( + cond=cond, + batch_size=cond["y"].shape[0], + return_intermediates=False, + use_alignment=True, + alignment_kwargs=alignment_kwargs, + verbose=False, + ) + if pred_seq.dtype != ms.float32: + pred_seq = pred_seq.float() + return pred_seq + + def _sample_without_alignment(self, cond): + """Generate predictions without alignment.""" + pred_seq = self.main_module.sample( + cond=cond, + batch_size=cond["y"].shape[0], + return_intermediates=False, + verbose=False, + ) + if pred_seq.dtype != ms.float32: + pred_seq = pred_seq.float() + return pred_seq + + def _update_metrics(self, aligned_preds, unaligned_preds, target_seq, metrics): + """Update evaluation metrics with new predictions.""" + for pred in aligned_preds: + metrics["mse_kc"] += ops.mse_loss(pred, target_seq) + metrics["mae_kc"] += self.maeloss(pred, target_seq) + self.main_module.test_aligned_score.update(pred, target_seq) + + for pred in unaligned_preds: + metrics["mse"] += ops.mse_loss(pred, target_seq) + metrics["mae"] += self.maeloss(pred, target_seq) + self.main_module.test_score.update(pred, target_seq) + + pred_bchw = self._convert_to_bchw(pred) + target_bchw = self._convert_to_bchw(target_seq) + metrics["ssim"] += self.main_module.test_ssim(pred_bchw, target_bchw)[0] + + return metrics + + def _convert_to_bchw(self, tensor): + """Convert tensor to batch-channel-height-width format for metrics.""" + return rearrange(tensor.asnumpy(), "b t h w c -> (b t) c h w") + + def _plt_pred( + self, data_idx, context_seq, target_seq, aligned_preds, unaligned_preds, step + ): + """Generate and save visualization of predictions.""" + pred_sequences = [pred[0].asnumpy() for pred in aligned_preds + unaligned_preds] + pred_labels = [ + f"{self.logging_prefix}_aligned_pred_{i}" for i in range(len(aligned_preds)) + ] + [f"{self.logging_prefix}_pred_{i}" for i in range(len(unaligned_preds))] + + self.save_vis_step_end( + data_idx=data_idx, + context_seq=context_seq[0].asnumpy(), + target_seq=target_seq[0].asnumpy(), + pred_seq=pred_sequences, + pred_label=pred_labels, + mode="test", + suffix=f"_step_{step}", + ) + + def _finalize_test(self, metrics): + """Complete test process and log final metrics.""" + total_time = (time.time() - self.start_time) * 1000 + self.logger.info(f"test cost: {total_time:.2f} ms") + self._compute_total_metrics(metrics) + self.logger.info("============== Test Completed ==============") + + def _compute_total_metrics(self, metrics): + """log_metrics""" + step_count = max(metrics["step"], 1) + if self.eval_unaligned: + self.logger.info(f"MSE: {metrics['mse'] / step_count}") + self.logger.info(f"MAE: {metrics['mae'] / step_count}") + self.logger.info(f"SSIM: {metrics['ssim'] / step_count}") + test_score = self.main_module.test_score.eval() + self.logger.info("SCORE:\n%s", json.dumps(test_score, indent=4)) + if self.use_alignment: + self.logger.info(f"KC_MSE: {metrics['mse_kc'] / step_count}") + self.logger.info(f"KC_MAE: {metrics['mae_kc'] / step_count}") + 
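+            # The sums accumulated in `metrics` are averaged over the number of
+            # evaluated steps below. The score objects aggregate threshold-based
+            # contingency counts; as a reference, the usual categorical scores for
+            # SEVIR-style nowcasting (illustrative definitions, not necessarily the
+            # exact SEVIRSkillScore implementation) are
+            #     POD = hits / (hits + misses)
+            #     CSI = hits / (hits + misses + false_alarms)
+            # evaluated per VIL threshold.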
aligned_score = self.main_module.test_aligned_score.eval() + self.logger.info("KC_SCORE:\n%s", json.dumps(aligned_score, indent=4)) + + def save_vis_step_end( + self, + data_idx: int, + context_seq: np.ndarray, + target_seq: np.ndarray, + pred_seq: Union[np.ndarray, Sequence[np.ndarray]], + pred_label: Union[str, Sequence[str]] = None, + mode: str = "train", + prefix: str = "", + suffix: str = "", + ): + """Save visualization of predictions with context and target.""" + example_data_idx_list = self.test_example_data_idx_list + if isinstance(pred_seq, Sequence): + seq_list = [context_seq, target_seq] + list(pred_seq) + label_list = ["context", "target"] + pred_label + else: + seq_list = [context_seq, target_seq, pred_seq] + label_list = ["context", "target", pred_label] + if data_idx in example_data_idx_list: + png_save_name = f"{prefix}{mode}_data_{data_idx}{suffix}.png" + vis_sevir_seq( + save_path=os.path.join(self.example_save_dir, png_save_name), + seq=seq_list, + label=label_list, + interval_real_time=10, + plot_stride=1, + fs=self.fs, + label_offset=self.label_offset, + label_avg_int=self.label_avg_int, + ) diff --git a/MindEarth/applications/nowcasting/PreDiff/src/diffusion/latent_diffusion.py b/MindEarth/applications/nowcasting/PreDiff/src/diffusion/latent_diffusion.py new file mode 100644 index 0000000000000000000000000000000000000000..1454a67a0b2ec0174fe110309dda2ed40c78819f --- /dev/null +++ b/MindEarth/applications/nowcasting/PreDiff/src/diffusion/latent_diffusion.py @@ -0,0 +1,1076 @@ +# Copyright 2023 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"Latent Diffusion Model" +import warnings +from typing import Sequence, Dict, Any, Callable +from copy import deepcopy +from functools import partial +import numpy as np +from tqdm import tqdm +from einops import rearrange +from omegaconf import OmegaConf + +import mindspore as ms +from mindspore import nn, ops, Tensor, Parameter, mint + +from src.utils import ( + DiagonalGaussianDistribution, + make_beta_schedule, + extract_into_tensor, + noise_like, + default, + parse_layout_shape, + disabled_train, + layout_to_in_out_slice, + calculate_ssim, + SEVIRSkillScore, +) +from src.sevir_dataset import SEVIRDataModule +from src.vae import AutoencoderKL +from src.knowledge_alignment.alignment_net import AvgIntensityAlignment +from .cuboid_transformer_unet import CuboidTransformerUNet + + +class LatentDiffusion(nn.Cell): + """ + Base class for latent space diffusion models. Implements core diffusion processes including + noise scheduling, model application, loss calculation, and latent space operations. Integrates + main UNet model, VAE, and conditioning modules with support for temporal alignment. 
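+
+    The forward (noising) process implemented by ``q_sample`` follows the standard
+    DDPM formulation
+
+        x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * eps,  eps ~ N(0, I),
+
+    and ``p_losses`` trains the main model to predict ``eps`` from ``x_t`` and the
+    conditioning, using an L1 or L2 objective weighted by a per-timestep
+    log-variance term (learnable when ``learn_logvar`` is set).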
+ """ + + def __init__( + self, + main_model: nn.Cell, + layout: str = "NTHWC", + data_shape: Sequence[int] = (10, 128, 128, 4), + timesteps=1000, + beta_schedule="linear", + loss_type="l2", + monitor="val/loss", + log_every_t=100, + clip_denoised=False, + linear_start=1e-4, + linear_end=2e-2, + cosine_s=8e-3, + given_betas=None, + original_elbo_weight=0.0, + v_posterior=0.0, + l_simple_weight=1.0, + learn_logvar=False, + logvar_init=0.0, + latent_shape: Sequence[int] = (10, 16, 16, 4), + first_stage_model: nn.Cell = None, + cond_stage_forward=None, + scale_by_std=False, + scale_factor=1.0, + ): + super().__init__() + + self.clip_denoised = clip_denoised + self.log_every_t = log_every_t + self.main_model = main_model + self.layout = layout + self.data_shape = data_shape + self.parse_layout_shape(layout=layout) + self.v_posterior = v_posterior + self.original_elbo_weight = original_elbo_weight + self.l_simple_weight = l_simple_weight + + if monitor is not None: + self.monitor = monitor + + self.register_schedule( + given_betas=given_betas, + beta_schedule=beta_schedule, + timesteps=timesteps, + linear_start=linear_start, + linear_end=linear_end, + cosine_s=cosine_s, + ) + + self.loss_type = loss_type + + self.learn_logvar = learn_logvar + logvar = ops.full(fill_value=logvar_init, size=(self.num_timesteps,)).astype( + ms.float32 + ) + if self.learn_logvar: + self.logvar = Parameter(logvar, requires_grad=True) + else: + self.logvar = Parameter(logvar, name="logvar", requires_grad=False) + + self.latent_shape = latent_shape + self.scale_by_std = scale_by_std + if not scale_by_std: + self.scale_factor = scale_factor + else: + self.logvar = Parameter( + scale_factor, name="scale_factor", requires_grad=False + ) + + self.instantiate_first_stage(first_stage_model) + self.instantiate_cond_stage(cond_stage_forward) + + def set_alignment(self, alignment_fn: Callable = None): + """ + Sets alignment function for denoising process after initialization. + Args: + alignment_fn (Callable): Alignment function with signature + `alignment_fn(zt, t, zc=None, y=None, **kwargs)` + """ + self.alignment_fn = alignment_fn + + def parse_layout_shape(self, layout): + """ + Parses data layout string to determine axis indices. + Args: + layout (str): Data layout specification (e.g., 'NTHWC') + """ + parsed_dict = parse_layout_shape(layout=layout) + self.batch_axis = parsed_dict["batch_axis"] + self.t_axis = parsed_dict["t_axis"] + self.h_axis = parsed_dict["h_axis"] + self.w_axis = parsed_dict["w_axis"] + self.c_axis = parsed_dict["c_axis"] + self.all_slice = [ + slice(None, None), + ] * len(layout) + + def extract_into_tensor(self, a, t, x_shape): + """Extracts schedule parameters into tensor format for current batch.""" + return extract_into_tensor( + a=a, t=t, x_shape=x_shape, batch_axis=self.batch_axis + ) + + @property + def loss_mean_dim(self): + """Computes mean dimensions for loss calculation excluding batch axis.""" + if not hasattr(self, "loss_m_dim"): + loss_m_dim = list(range(len(self.layout))) + loss_m_dim.pop(self.batch_axis) + self.loss_m_dim = tuple(loss_m_dim) + return self.loss_m_dim + + def get_batch_latent_shape(self, batch_size=1): + """ + Generates latent shape with specified batch size. 
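+        For example, with the default ``latent_shape=(10, 16, 16, 4)`` and
+        ``layout="NTHWC"`` (batch axis 0), ``batch_size=2`` yields
+        ``(2, 10, 16, 16, 4)``.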
+ Args: + batch_size (int): Desired batch size + """ + batch_latent_shape = deepcopy(list(self.latent_shape)) + batch_latent_shape.insert(self.batch_axis, batch_size) + self.batch_latent_shape = tuple(batch_latent_shape) + return self.batch_latent_shape + + def register_schedule( + self, + given_betas=None, + beta_schedule="linear", + timesteps=1000, + linear_start=1e-4, + linear_end=2e-2, + cosine_s=8e-3, + ): + """ + Registers diffusion schedule parameters and precomputes necessary tensors. + Args: + given_betas (Tensor): Custom beta values + beta_schedule (str): Schedule type ('linear', 'cosine') + timesteps (int): Number of diffusion steps + linear_start (float): Linear schedule start value + linear_end (float): Linear schedule end value + cosine_s (float): Cosine schedule parameter + """ + if given_betas is not None: + betas = given_betas + else: + betas = make_beta_schedule( + beta_schedule, + timesteps, + linear_start=linear_start, + linear_end=linear_end, + cosine_s=cosine_s, + ) + alphas = 1.0 - betas + alphas_cumprod = np.cumprod(alphas, axis=0) + alphas_cumprod_prev = np.append(1.0, alphas_cumprod[:-1]) + + (timesteps,) = betas.shape + self.num_timesteps = int(timesteps) + self.linear_start = linear_start + self.linear_end = linear_end + assert ( + alphas_cumprod.shape[0] == self.num_timesteps + ), "alphas have to be defined for each timestep" + + to_mindspore = partial(Tensor, dtype=ms.float32) + self.betas = Parameter(to_mindspore(betas), name="betas", requires_grad=False) + self.alphas_cumprod = Parameter( + to_mindspore(alphas_cumprod), name="alphas_cumprod", requires_grad=False + ) + self.alphas_cumprod_prev = Parameter( + to_mindspore(alphas_cumprod_prev), + name="alphas_cumprod_prev", + requires_grad=False, + ) + self.sqrt_alphas_cumprod = Parameter( + to_mindspore(np.sqrt(alphas_cumprod)), + name="sqrt_alphas_cumprod", + requires_grad=False, + ) + self.sqrt_one_minus_alphas_cumprod = Parameter( + to_mindspore(np.sqrt(1.0 - alphas_cumprod)), + name="sqrt_one_minus_alphas_cumprod", + requires_grad=False, + ) + self.log_one_minus_alphas_cumprod = Parameter( + to_mindspore(np.log(1.0 - alphas_cumprod)), + name="log_one_minus_alphas_cumprod", + requires_grad=False, + ) + self.sqrt_recip_alphas_cumprod = Parameter( + to_mindspore(np.sqrt(1.0 / alphas_cumprod)), + name="sqrt_recip_alphas_cumprod", + requires_grad=False, + ) + self.sqrt_recipm1_alphas_cumprod = Parameter( + to_mindspore(np.sqrt(1.0 / alphas_cumprod - 1)), + name="sqrt_recipm1_alphas_cumprod", + requires_grad=False, + ) + + posterior_variance = (1 - self.v_posterior) * betas * ( + 1.0 - alphas_cumprod_prev + ) / (1.0 - alphas_cumprod) + self.v_posterior * betas + self.posterior_variance = Parameter( + to_mindspore(posterior_variance), + name="posterior_variance", + requires_grad=False, + ) + self.posterior_log_variance_clipped = Parameter( + to_mindspore(np.log(np.maximum(posterior_variance, 1e-20))), + name="posterior_log_variance_clipped", + requires_grad=False, + ) + self.posterior_mean_coef1 = Parameter( + to_mindspore(betas * np.sqrt(alphas_cumprod_prev) / (1.0 - alphas_cumprod)), + name="posterior_mean_coef1", + requires_grad=False, + ) + self.posterior_mean_coef2 = Parameter( + to_mindspore( + (1.0 - alphas_cumprod_prev) * np.sqrt(alphas) / (1.0 - alphas_cumprod) + ), + name="posterior_mean_coef2", + requires_grad=False, + ) + + lvlb_weights = self.betas**2 / ( + 2 + * self.posterior_variance + * to_mindspore(alphas) + * (1 - self.alphas_cumprod) + ) + lvlb_weights[0] = lvlb_weights[1] + self.lvlb_weights 
= Parameter( + lvlb_weights, name="lvlb_weights", requires_grad=False + ) + assert not ops.isnan(self.lvlb_weights).all() + + def instantiate_first_stage(self, first_stage_model): + """ + Initializes and freezes the first stage autoencoder model. + Args: + first_stage_model (nn.Cell): Autoencoder model instance + """ + if isinstance(first_stage_model, nn.Cell): + model = first_stage_model + else: + assert first_stage_model is None + raise NotImplementedError("No default first_stage_model supported yet!") + self.first_stage_model = model.set_train(False) + self.first_stage_model.train = disabled_train + for param in self.first_stage_model.trainable_params(): + param.requires_grad = False + + def instantiate_cond_stage(self, cond_stage_forward): + """Configures conditioning stage encoder with spatial rearrangement.""" + self.cond_stage_model = self.first_stage_model + for param in self.cond_stage_model.trainable_params(): + param.requires_grad = False + cond_stage_forward = self.cond_stage_model.encode + + def wrapper(cond_stage_forward: Callable): + def func(c: Dict[str, Any]): + c = c.get("y") + batch_size = c.shape[self.batch_axis] + c = c.transpose(0, 1, 4, 2, 3) + n_new, t_new, c_new, h_new, w_new = c.shape + c = c.reshape(n_new * t_new, c_new, h_new, w_new) + c = cond_stage_forward(c) + if isinstance(c, DiagonalGaussianDistribution): + c = c.mode() + n_new, c_new, h_new, w_new = c.shape + c = c.reshape(batch_size, -1, c_new, h_new, w_new) + c = c.transpose(0, 1, 3, 4, 2) + return c + + return func + + self.cond_stage_forward = wrapper(cond_stage_forward) + + def get_first_stage_encoding(self, encoder_posterior): + """ + Extracts latent representation from encoder output. + Args: + encoder_posterior (Tensor/DiagonalGaussianDistribution): Encoder output + """ + if isinstance(encoder_posterior, DiagonalGaussianDistribution): + z = encoder_posterior.sample() + elif isinstance(encoder_posterior, Tensor): + z = encoder_posterior + else: + raise NotImplementedError( + f"encoder_posterior of type '{type(encoder_posterior)}' not yet implemented" + ) + return self.scale_factor * z + + @property + def einops_layout(self): + """Returns Einops layout string for data rearrangement.""" + return " ".join(self.layout) + + @property + def einops_spatial_layout(self): + """Generates spatial Einops pattern for 2D/3D data handling.""" + if not hasattr(self, "_einops_spatial_layout"): + assert len(self.layout) == 4 or len(self.layout) == 5 + self._einops_spatial_layout = ( + "(N T) C H W" if self.layout.find("T") else "N C H W" + ) + return self._einops_spatial_layout + + def decode_first_stage(self, z): + """ + Decodes latent representation to data space with spatial rearrangement. + Args: + z (Tensor): Latent tensor + """ + z = 1.0 / self.scale_factor * z + batch_size = z.shape[self.batch_axis] + z = rearrange( + z.asnumpy(), f"{self.einops_layout} -> {self.einops_spatial_layout}" + ) + z = Tensor.from_numpy(z) + output = self.first_stage_model.decode(z) + output = rearrange( + output.asnumpy(), + f"{self.einops_spatial_layout} -> {self.einops_layout}", + N=batch_size, + ) + output = Tensor.from_numpy(output) + return output + + def encode_first_stage(self, x): + """ + Encodes input data into latent space. + Args: + x (Tensor): Input data tensor + """ + encoder_posterior = self.first_stage_model.encode(x) + output = self.get_first_stage_encoding(encoder_posterior) + return output + + def apply_model(self, x_noisy, t, cond): + """ + Applies main UNet model to denoise inputs. 
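+        Here ``main_model`` is the CuboidTransformerUNet, whose ``construct``
+        concatenates ``cond`` with ``x_noisy`` along the time axis and appends an
+        observation-indicator channel before denoising.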
+ Args: + x_noisy (Tensor): Noisy input tensor + t (Tensor): Time step tensor + cond (Dict): Conditioning information + Returns: + Tensor: Denoising model output + """ + x_recon = self.main_model(x_noisy, t, cond) + if isinstance(x_recon, tuple): + return x_recon[0] + return x_recon + + def q_sample(self, x_start, t, noise=None): + """ + Adds noise to clean data according to diffusion schedule. + Args: + x_start (Tensor): Clean data tensor + t (Tensor): Time step tensor + noise (Tensor): Optional noise tensor + Returns: + Tensor: Noisy data tensor + """ + noise = default(noise, lambda: ops.randn_like(x_start)) + return ( + self.extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) + * x_start + + self.extract_into_tensor( + self.sqrt_one_minus_alphas_cumprod, t, x_start.shape + ) + * noise + ) + + def get_loss(self, pred, target, mean=True): + """ + Calculates loss between prediction and target. + Args: + pred (Tensor): Model predictions + target (Tensor): Target values + mean (bool): Whether to return mean loss + Returns: + Tensor: Loss value(s) + """ + if self.loss_type == "l1": + loss = (target - pred).abs() + if mean: + loss = loss.mean() + elif self.loss_type == "l2": + if mean: + loss = mint.nn.functional.mse_loss(target, pred) + else: + loss = mint.nn.functional.mse_loss(target, pred, reduction="none") + else: + raise NotImplementedError("unknown loss type '{loss_type}'") + + return loss + + def p_losses(self, x_start, cond, t, noise=None): + """ + Computes diffusion training loss for given time steps. + Args: + x_start (Tensor): Clean data tensor + cond (Dict): Conditioning information + t (Tensor): Time step tensor + noise (Tensor): Optional noise tensor + Returns: + Tensor: Total training loss + """ + noise = default(noise, lambda: ops.randn_like(x_start)) + x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise) + model_output = self.apply_model(x_noisy, t, cond) + loss_simple = self.get_loss(model_output, noise, mean=False).mean( + axis=self.loss_mean_dim + ) + + logvar_t = self.logvar[t] + + loss = loss_simple / ops.exp(logvar_t) + logvar_t + + loss = self.l_simple_weight * loss.mean() + return loss + + def predict_start_from_noise(self, x_t, t, noise): + """ + Reconstructs clean data from noisy input and predicted noise. + Args: + x_t (Tensor): Noisy data tensor + t (Tensor): Time step tensor + noise (Tensor): Predicted noise tensor + Returns: + Tensor: Reconstructed clean data + """ + sqrt_recip_alphas_cumprod_t = self.extract_into_tensor( + self.sqrt_recip_alphas_cumprod, t, x_t.shape + ) + sqrt_recipm1_alphas_cumprod_t = self.extract_into_tensor( + self.sqrt_recipm1_alphas_cumprod, t, x_t.shape + ) + term1 = sqrt_recip_alphas_cumprod_t * x_t + term2 = sqrt_recipm1_alphas_cumprod_t * noise + pred = term1 - term2 + return pred + + def q_posterior(self, x_start, x_t, t): + """ + Calculates posterior distribution parameters for given time steps. 
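+        The returned statistics are the closed-form DDPM posterior
+        q(x_{t-1} | x_t, x_0) built from the tensors registered in
+        ``register_schedule``; for ``v_posterior = 0`` they reduce to
+
+            mean = beta_t * sqrt(alpha_bar_{t-1}) / (1 - alpha_bar_t) * x_0
+                   + (1 - alpha_bar_{t-1}) * sqrt(alpha_t) / (1 - alpha_bar_t) * x_t
+            var  = beta_t * (1 - alpha_bar_{t-1}) / (1 - alpha_bar_t)
+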
+ Args: + x_start (Tensor): Clean data tensor + x_t (Tensor): Noisy data tensor + t (Tensor): Time step tensor + Returns: + Tuple[Tensor]: (posterior mean, variance, log variance) + """ + posterior_mean = ( + self.extract_into_tensor(self.posterior_mean_coef1, t, x_t.shape) * x_start + + self.extract_into_tensor(self.posterior_mean_coef2, t, x_t.shape) * x_t + ) + posterior_variance = self.extract_into_tensor( + self.posterior_variance, t, x_t.shape + ) + posterior_log_variance_clipped = self.extract_into_tensor( + self.posterior_log_variance_clipped, t, x_t.shape + ) + return posterior_mean, posterior_variance, posterior_log_variance_clipped + + def p_mean_variance( + self, + zt, + zc, + t, + clip_denoised: bool, + return_x0=False, + score_corrector=None, + corrector_kwargs=None, + ): + """ + Computes predicted mean and variance during denoising. + Args: + zt (Tensor): Current latent sample + zc (Tensor): Conditioning tensor + t (Tensor): Time step tensor + clip_denoised (bool): Whether to clip denoised outputs + return_x0 (bool): Whether to return reconstructed x0 + score_corrector (Callable): Optional score correction function + corrector_kwargs (Dict): Correction function parameters + Returns: + Tuple[Tensor]: (mean, variance, log variance, [reconstructed x0]) + """ + t_in = t + model_out = self.apply_model(zt, t_in, zc) + if score_corrector is not None: + model_out = score_corrector.modify_score( + self, model_out, zt, t, zc, **corrector_kwargs + ) + z_recon = self.predict_start_from_noise(zt, t=t, noise=model_out) + if clip_denoised: + z_recon = z_recon.clamp(-1.0, 1.0) + model_mean, posterior_variance, posterior_log_variance = self.q_posterior( + x_start=z_recon, x_t=zt, t=t + ) + if return_x0: + return model_mean, posterior_variance, posterior_log_variance, z_recon + return model_mean, posterior_variance, posterior_log_variance + + def aligned_mean(self, zt, t, zc, y, orig_mean, orig_log_var, **kwargs): + """ + Calculates aligned mean using gradient-based alignment function. + Args: + zt (Tensor): Current latent sample + t (Tensor): Time step tensor + zc (Tensor): Conditioning tensor + y (Tensor): Ground truth tensor + orig_mean (Tensor): Original mean + orig_log_var (Tensor): Original log variance + **kwargs: Additional alignment parameters + Returns: + Tensor: Aligned mean tensor + """ + align_gradient = self.alignment_fn(zt, t, zc=zc, y=y, **kwargs) + new_mean = orig_mean - (0.5 * orig_log_var).exp() * align_gradient + return new_mean + + def p_sample( + self, + zt, + zc, + t, + y=None, + use_alignment=False, + alignment_kwargs=None, + clip_denoised=False, + return_x0=False, + temperature=1.0, + noise_dropout=0.0, + score_corrector=None, + corrector_kwargs=None, + ): + """ + Single step diffusion sampling. 
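+        One reverse step: the mean and log-variance predicted by
+        ``p_mean_variance`` are combined with fresh noise as
+
+            z_{t-1} = mean + 1[t > 0] * exp(0.5 * log_var) * temperature * noise,
+
+        and when ``use_alignment`` is set the mean is first shifted by
+        ``-exp(0.5 * log_var) * alignment_gradient`` (see ``aligned_mean``).
+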
+ Args: + zt (Tensor): Current noisy sample at time step t + zc (Tensor/Dict): Condition input (latent or processed) + t (Tensor): Time step tensor + y (Tensor, optional): Additional conditioning information + use_alignment (bool): Whether to apply alignment correction + alignment_kwargs (dict, optional): Parameters for alignment correction + clip_denoised (bool): Clip model output to [-1,1] range + return_x0 (bool): Return estimated x0 along with sample + temperature (float): Noise scaling factor + noise_dropout (float): Dropout rate for noise component + score_corrector (object, optional): Model output corrector instance + corrector_kwargs (dict, optional): Parameters for score correction + + Returns: + Tensor: Next denoised sample + Tensor (optional): Estimated x0 if return_x0 is True + """ + batch_size = zt.shape[self.batch_axis] + outputs = self.p_mean_variance( + zt=zt, + zc=zc, + t=t, + clip_denoised=clip_denoised, + return_x0=return_x0, + score_corrector=score_corrector, + corrector_kwargs=corrector_kwargs, + ) + if use_alignment: + if alignment_kwargs is None: + alignment_kwargs = {} + model_mean, posterior_variance, model_log_variance, *_ = outputs + model_mean = self.aligned_mean( + zt=zt, + t=t, + zc=zc, + y=y, + orig_mean=model_mean, + orig_log_var=model_log_variance, + **alignment_kwargs, + ) + outputs = (model_mean, posterior_variance, model_log_variance, *outputs[3:]) + if return_x0: + model_mean, _, model_log_variance, x0 = outputs + else: + model_mean, _, model_log_variance = outputs + + noise = noise_like(zt.shape) * temperature + if noise_dropout > 0.0: + noise = ops.dropout(noise, p=noise_dropout) + # no noise when t == 0 + nonzero_mask_shape = [ + 1, + ] * len(zt.shape) + nonzero_mask_shape[self.batch_axis] = batch_size + nonzero_mask = (1 - (t == 0).float()).reshape(*nonzero_mask_shape) + + if return_x0: + return ( + model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise, + x0, + ) + return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise + + def p_sample_loop( + self, + cond, + shape, + y=None, + use_alignment=False, + alignment_kwargs=None, + return_intermediates=False, + x_t=None, + verbose=False, + timesteps=None, + mask=None, + x0=None, + start_t=None, + log_every_t=None, + ): + """ + Full diffusion sampling loop. 
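+        Schematically, a simplified sketch of the loop implemented below:
+
+            img = x_t if x_t is not None else randn(shape)
+            for i in reversed(range(timesteps)):
+                img = p_sample(zt=img, zc=cond, t=full((batch_size,), i), y=y, ...)
+                if mask is not None:
+                    img = q_sample(x0, ts) * mask + (1 - mask) * img
+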
+ Args: + cond (Tensor/Dict): Conditioning input (processed) + shape (tuple): Output tensor shape (B, C, H, W) + y (Tensor, optional): Additional conditioning info + use_alignment (bool): Enable alignment correction during sampling + alignment_kwargs (dict, optional): Alignment parameters + return_intermediates (bool): Return intermediate steps + x_t (Tensor, optional): Initial noise sample (default: random) + verbose (bool): Show progress bar + timesteps (int): Number of sampling steps + mask (Tensor, optional): Mask for conditional generation (requires x0) + x0 (Tensor, optional): Original image for inpainting/conditional generation + start_t (int): Override maximum time step + log_every_t (int): Frequency of intermediate saves + + Returns: + Tensor: Final generated sample + list[Tensor] (optional): Intermediate samples if requested + """ + + if not log_every_t: + log_every_t = self.log_every_t + batch_size = shape[self.batch_axis] + if x_t is None: + img = ops.randn(shape) + + else: + img = x_t + + intermediates = [img] + if timesteps is None: + timesteps = self.num_timesteps + + if start_t is not None: + timesteps = min(timesteps, start_t) + iterator = ( + tqdm(reversed(range(0, timesteps)), desc="Sampling t", total=timesteps) + if verbose + else reversed(range(0, timesteps)) + ) + + if mask is not None: + assert x0 is not None + assert x0.shape[2:3] == mask.shape[2:3] # spatial size has to match + for i in iterator: + ts = ops.full((batch_size,), i, dtype=ms.int64) + img = self.p_sample( + zt=img, + zc=cond, + t=ts, + y=y, + use_alignment=use_alignment, + alignment_kwargs=alignment_kwargs, + clip_denoised=self.clip_denoised, + ) + if mask is not None: + img_orig = self.q_sample(x0, ts) + img = img_orig * mask + (1.0 - mask) * img + + if i % log_every_t == 0 or i == timesteps - 1: + intermediates.append(img) + + if return_intermediates: + return img, intermediates + return img + + def sample( + self, + cond, + batch_size=16, + use_alignment=False, + alignment_kwargs=None, + return_intermediates=False, + x_t=None, + verbose=False, + timesteps=None, + mask=None, + x0=None, + shape=None, + return_decoded=True, + ): + """ + High-level sampling interface with conditioning handling. 
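+        Illustrative call, mirroring how ``DiffusionInferrence`` invokes it
+        (``cond`` is the conditioning dict returned by ``get_input``):
+
+            pred_seq = main_module.sample(
+                cond=cond,
+                batch_size=cond["y"].shape[0],
+                return_intermediates=False,
+                verbose=False,
+            )
+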
+ + Args: + cond (Tensor/Dict): Raw conditioning input (e.g., text/image) + batch_size (int): Number of samples to generate + use_alignment (bool): Enable alignment correction + alignment_kwargs (dict, optional): Alignment parameters + return_intermediates (bool): Return intermediate steps + x_t (Tensor, optional): Initial noise sample + verbose (bool): Show progress + timesteps (int): Sampling steps + mask (Tensor, optional): Inpainting mask (requires x0) + x0 (Tensor, optional): Original image for conditioning + shape (tuple, optional): Output shape override + return_decoded (bool): Return decoded image instead of latent + + Returns: + Tensor: Generated image (decoded if return_decoded) + list[Tensor] (optional): Decoded intermediate steps if requested + """ + if shape is None: + shape = self.get_batch_latent_shape(batch_size=batch_size) + if self.cond_stage_model is not None: + assert cond is not None + cond_tensor_slice = [ + slice(None, None), + ] * len(self.data_shape) + cond_tensor_slice[self.batch_axis] = slice(0, batch_size) + if isinstance(cond, dict): + zc = { + key: ( + cond[key][cond_tensor_slice] + if not isinstance(cond[key], list) + else list(map(lambda x: x[cond_tensor_slice], cond[key])) + ) + for key in cond + } + else: + zc = ( + [c[cond_tensor_slice] for c in cond] + if isinstance(cond, list) + else cond[cond_tensor_slice] + ) + zc = self.cond_stage_forward(zc) + else: + zc = cond if isinstance(cond, Tensor) else cond.get("y", None) + y = cond if isinstance(cond, Tensor) else cond.get("y", None) + output = self.p_sample_loop( + cond=zc, + shape=shape, + y=y, + use_alignment=use_alignment, + alignment_kwargs=alignment_kwargs, + return_intermediates=return_intermediates, + x_t=x_t, + verbose=verbose, + timesteps=timesteps, + mask=mask, + x0=x0, + ) + if return_decoded: + if return_intermediates: + samples, intermediates = output + decoded_samples = self.decode_first_stage(samples) + decoded_intermediates = [ + self.decode_first_stage(ele) for ele in intermediates + ] + output = [decoded_samples, decoded_intermediates] + else: + output = self.decode_first_stage(output) + return output + + + +class PreDiffModule(LatentDiffusion): + """ + Main module for pre-training diffusion models with latent representations. + Integrates configuration loading, model creation, alignment setup, metric initialization, + and visualization parameters. Extends LatentDiffusion to handle cuboid-based UNet architectures + and knowledge alignment for sequential data generation tasks. 
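+
+    Typical usage is to instantiate it from an OmegaConf YAML file; the path below
+    is a placeholder:
+
+        module = PreDiffModule(oc_file="path/to/diffusion_cfg.yaml")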
+ """ + + def __init__(self, oc_file: str = None): + self.oc = self._load_configs(oc_file) + latent_model = self._create_latent_model() + first_stage_model = self._create_vae_model() + super().__init__( + **self._prepare_parent_init_params(latent_model, first_stage_model) + ) + self._setup_alignment() + self._initialize_metrics() + self._setup_visualization() + + def _load_configs(self, oc_file): + """Loads all configuration files through a unified entry point.""" + oc_from_file = OmegaConf.load(open(oc_file, "r")) if oc_file else None + return self.get_base_config(oc_from_file=oc_from_file) + + def _create_latent_model(self): + """Builds the CuboidTransformerUNet model based on configurations.""" + latent_model_cfg = OmegaConf.to_object(self.oc.model.latent_model) + return CuboidTransformerUNet( + **{ + k: latent_model_cfg[k] + for k in latent_model_cfg + }, + ) + + def _process_attention_patterns(self, cfg, num_blocks): + """Processes attention patterns from configuration settings.""" + if isinstance(cfg["self_pattern"], str): + return [cfg["self_pattern"]] * num_blocks + return OmegaConf.to_container(cfg["self_pattern"]) + + def _create_vae_model(self): + """Creates and loads pretrained weights for the VAE model.""" + vae_cfg = OmegaConf.to_object(self.oc.model.vae) + model = AutoencoderKL( + **{ + k: vae_cfg[k] + for k in vae_cfg + if k not in ["pretrained_ckpt_path", "data_channels"] + } + ) + self._load_pretrained_weights(model, vae_cfg["pretrained_ckpt_path"]) + return model + + def _load_pretrained_weights(self, model, ckpt_path): + """Loads pretrained weights into the given model if a checkpoint path is provided.""" + if ckpt_path: + param_dict = ms.load_checkpoint(ckpt_path) + param_not_load, _ = ms.load_param_into_net(model, param_dict) + if param_not_load: + print(f"Unloaded AutoencoderKLparameters: {param_not_load}") + else: + warnings.warn( + "Pretrained weights for AutoencoderKL not set. Running sanity check only." 
+ ) + + def _prepare_parent_init_params(self, latent_model, first_stage_model): + """Prepares initialization parameters for the parent class.""" + diffusion_cfg = OmegaConf.to_object(self.oc.model.diffusion) + return { + "main_model": latent_model, + "layout": self.oc.layout.layout, + "loss_type": self.oc.optim.loss_type, + "monitor": self.oc.optim.monitor, + "first_stage_model": first_stage_model, + **{ + k: diffusion_cfg[k] + for k in diffusion_cfg + if k not in ["latent_cond_shape"] + }, + } + + def _setup_alignment(self): + """Sets up alignment using AvgIntensityAlignment if specified in configurations.""" + # from src.knowledge_alignment.alignment_net import AvgIntensityAlignment + + knowledge_cfg = OmegaConf.to_object(self.oc.model.align) + self.alignment_type = knowledge_cfg["alignment_type"] + self.use_alignment = self.alignment_type is not None + + if self.use_alignment: + self.alignment_obj = AvgIntensityAlignment( + guide_scale=knowledge_cfg["guide_scale"], + model_args=knowledge_cfg["model_args"], + model_ckpt_path=knowledge_cfg["model_ckpt_path"], + ) + self.alignment_obj.model.set_train(False) + self.set_alignment(self.alignment_obj.get_mean_shift) + else: + self.set_alignment(None) + + def _initialize_metrics(self): + """Initializes metrics for evaluation based on configurations.""" + if self.oc.eval.eval_unaligned: + self._init_unaligned_metrics() + if self.oc.eval.eval_aligned: + self._init_aligned_metrics() + + def _init_unaligned_metrics(self): + """Initializes unaligned metrics for evaluation.""" + common_args = { + "mode": self.oc.data.metrics_mode, + "seq_in": self.oc.layout.t_out, + "layout": self.layout, + "threshold_list": self.oc.data.threshold_list, + "metrics_list": self.oc.data.metrics_list, + "eps": 1e-4, + } + + self.valid_score = SEVIRSkillScore(**common_args) + + self.test_ssim = calculate_ssim + self.test_aligned_ssim = calculate_ssim + self.test_score = SEVIRSkillScore(**common_args) + + def _init_aligned_metrics(self): + """Initializes aligned metrics for evaluation.""" + common_args = { + "mode": self.oc.data.metrics_mode, + "seq_in": self.oc.layout.t_out, + "layout": self.layout, + "threshold_list": self.oc.data.threshold_list, + "metrics_list": self.oc.data.metrics_list, + "eps": 1e-4, + } + + self.valid_aligned_score = SEVIRSkillScore(**common_args) + + self.test_aligned_ssim = nn.SSIM() + self.test_aligned_score = SEVIRSkillScore(**common_args) + + def _setup_visualization(self): + """Sets up visualization parameters based on configurations.""" + self.logging_prefix = self.oc.logging.logging_prefix + self.train_example_data_idx_list = list( + self.oc.eval.train_example_data_idx_list + ) + self.val_example_data_idx_list = list(self.oc.eval.val_example_data_idx_list) + self.test_example_data_idx_list = list(self.oc.eval.test_example_data_idx_list) + + def get_base_config(self, oc_from_file=None): + """Merges base configuration with configuration loaded from file.""" + if oc_from_file is None: + raise ValueError("oc_from_file is required but not provided") + oc = OmegaConf.create() + oc = OmegaConf.merge(oc, oc_from_file) + return oc + + @classmethod + def get_total_num_steps( + cls, num_samples: int, total_batch_size: int, epoch: int = None + ): + """ + Parameters + ---------- + num_samples: int + The number of samples of the datasets. `num_samples / micro_batch_size` is the number of steps per epoch. 
+ total_batch_size: int + `total_batch_size == micro_batch_size * world_size * grad_accum` + epoch: int + """ + if epoch is None: + epoch = cls.get_optim_config().max_epochs + return int(epoch * num_samples / total_batch_size) + + @staticmethod + def get_sevir_datamodule( + dataset_cfg, micro_batch_size: int = 1, num_workers: int = 8 + ): + """Creates and returns a SEVIRDataModule instance based on dataset configurations.""" + dm = SEVIRDataModule( + sevir_dir=dataset_cfg["root_dir"], + seq_in=dataset_cfg["seq_in"], + sample_mode=dataset_cfg["sample_mode"], + stride=dataset_cfg["stride"], + batch_size=micro_batch_size, + layout=dataset_cfg["layout"], + output_type=np.float32, + preprocess=True, + rescale_method="01", + verbose=False, + aug_mode=dataset_cfg["aug_mode"], + dataset_name=dataset_cfg["dataset_name"], + start_date=dataset_cfg["start_date"], + train_val_split_date=dataset_cfg["train_val_split_date"], + train_test_split_date=dataset_cfg["train_test_split_date"], + end_date=dataset_cfg["end_date"], + val_ratio=dataset_cfg["val_ratio"], + num_workers=num_workers, + raw_seq_len=dataset_cfg["raw_seq_len"] + ) + return dm + + @property + def in_slice(self): + """Returns the input slice based on layout and sequence length configurations.""" + if not hasattr(self, "_in_slice"): + in_slice, out_slice = layout_to_in_out_slice( + layout=self.oc.layout.layout, + t_in=self.oc.layout.t_in, + t_out=self.oc.layout.t_out, + ) + self._in_slice = in_slice + self._out_slice = out_slice + return self._in_slices + + @property + def out_slice(self): + """Returns the output slice based on layout and sequence length configurations.""" + if not hasattr(self, "_out_slice"): + in_slice, out_slice = layout_to_in_out_slice( + layout=self.oc.layout.layout, + t_in=self.oc.layout.t_in, + t_out=self.oc.layout.t_out, + ) + self._in_slice = in_slice + self._out_slice = out_slice + return self._out_slice + + def get_input(self, batch, **kwargs): + """Extracts input data and conditioning information from a raw data batch.""" + return self._get_input_sevirlr( + batch=batch, return_verbose=kwargs.get("return_verbose", False) + ) + + def _get_input_sevirlr(self, batch, return_verbose=False): + """Specific implementation of input extraction for SEVIRLR dataset.""" + seq = batch + in_seq = seq[self.in_slice] + out_seq = seq[self.out_slice] + if return_verbose: + return out_seq, {"y": in_seq}, in_seq + return out_seq, {"y": in_seq} diff --git a/MindEarth/applications/nowcasting/PreDiff/src/diffusion/solver.py b/MindEarth/applications/nowcasting/PreDiff/src/diffusion/solver.py new file mode 100644 index 0000000000000000000000000000000000000000..103cf70643de7aa19105dacb28116e31938d750a --- /dev/null +++ b/MindEarth/applications/nowcasting/PreDiff/src/diffusion/solver.py @@ -0,0 +1,140 @@ +# Copyright 2023 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"diffusion model training" +import time +import os + + +import mindspore as ms +from mindspore import ops, nn +from mindspore.train.serialization import save_checkpoint + +from src.sevir_dataset import SEVIRDataset + + +class DiffusionTrainer(nn.Cell): + """ + Class managing the training pipeline for diffusion models. Handles dataset processing, + optimizer configuration, gradient clipping, checkpoint saving, and logging. + """ + def __init__(self, main_module, dm, logger, config): + """ + Initialize trainer with model, data module, logger, and configuration. + Args: + main_module: Main diffusion model to be trained + dm: Data module providing training dataset + logger: Logging utility for training progress + config: Configuration dictionary containing hyperparameters + """ + super().__init__() + self.main_module = main_module + self.traindataset = dm.sevir_train + self.logger = logger + self.datasetprocessing = SEVIRDataset( + data_types=["vil"], + layout="NHWT", + rescale_method=config.get("rescale_method", "01"), + ) + self.example_save_dir = config["summary"].get("summary_dir", "./summary") + self.fs = config["eval"].get("fs", 20) + self.label_offset = config["eval"].get("label_offset", [-0.5, 0.5]) + self.label_avg_int = config["eval"].get("label_avg_int", False) + self.current_epoch = 0 + self.learn_logvar = ( + config.get("model", {}).get("diffusion", {}).get("learn_logvar", False) + ) + self.logvar = main_module.logvar + self.maeloss = nn.MAELoss() + self.optim_config = config["optim"] + self.clip_norm = config.get("clip_norm", 2) + self.ckpt_dir = os.path.join(self.example_save_dir, "ckpt") + self.keep_ckpt_max = config["summary"].get("keep_ckpt_max", 100) + self.ckpt_history = [] + self.grad_clip_fn = ops.clip_by_global_norm + self.optimizer = nn.Adam(params=self.main_module.main_model.trainable_params(), + learning_rate=config["optim"].get("lr", 1e-5)) + os.makedirs(self.ckpt_dir, exist_ok=True) + + def train(self, total_steps: int): + """Execute complete training pipeline.""" + self.main_module.main_model.set_train(True) + self.logger.info(f"total_steps: {total_steps}") + self.logger.info("Initializing training process...") + loss_processor = Trainonestepforward(self.main_module) + grad_func = ms.ops.value_and_grad(loss_processor, None, self.main_module.main_model.trainable_params()) + for epoch in range(self.optim_config["max_epochs"]): + epoch_loss = 0.0 + epoch_start = time.time() + + iterator = self.traindataset.create_dict_iterator() + assert iterator, "dataset is empty" + batch_idx = 0 + for batch_idx, batch in enumerate(iterator): + processed_data = self.datasetprocessing.process_data(batch["vil"]) + loss_value, gradients = grad_func(processed_data) + clipped_grads = self.grad_clip_fn(gradients, self.clip_norm) + self.optimizer(clipped_grads) + epoch_loss += loss_value.asnumpy() + self.logger.info( + f"epoch: {epoch} step: {batch_idx}, loss: {loss_value}" + ) + self._save_ckpt(epoch) + epoch_time = time.time() - epoch_start + self.logger.info( + f"Epoch {epoch} completed in {epoch_time:.2f}s | " + f"Avg Loss: {epoch_loss/(batch_idx+1):.4f}" + ) + + def _save_ckpt(self, epoch: int): + """Save model ckpt with rotation policy""" + ckpt_file = f"diffusion_epoch{epoch}.ckpt" + ckpt_path = os.path.join(self.ckpt_dir, ckpt_file) + + save_checkpoint(self.main_module.main_model, ckpt_path) + self.ckpt_history.append(ckpt_path) + + if len(self.ckpt_history) > self.keep_ckpt_max: + removed_ckpt = 
self.ckpt_history.pop(0) + os.remove(removed_ckpt) + self.logger.info(f"Removed outdated ckpt: {removed_ckpt}") + + +class Trainonestepforward(nn.Cell): + """A neural network cell that performs one training step forward pass for a diffusion model. + This class encapsulates the forward pass computation for training a diffusion model, + handling the input processing, latent space encoding, conditioning, and loss calculation. + Args: + model (nn.Cell): The main diffusion model containing the necessary submodules + for encoding, conditioning, and loss computation. + """ + + def __init__(self, model): + super().__init__() + self.main_module = model + + def construct(self, inputs): + """Perform one forward training step and compute the loss.""" + x, condition = self.main_module.get_input(inputs) + x = x.transpose(0, 1, 4, 2, 3) + n, t_, c_, h, w = x.shape + x = x.reshape(n * t_, c_, h, w) + z = self.main_module.encode_first_stage(x) + _, c_z, h_z, w_z = z.shape + z = z.reshape(n, -1, c_z, h_z, w_z) + z = z.transpose(0, 1, 3, 4, 2) + t = ops.randint(0, self.main_module.num_timesteps, (n,)).long() + zc = self.main_module.cond_stage_forward(condition) + loss = self.main_module.p_losses(z, zc, t, noise=None) + return loss diff --git a/MindEarth/applications/nowcasting/PreDiff/src/diffusion/time_embed.py b/MindEarth/applications/nowcasting/PreDiff/src/diffusion/time_embed.py new file mode 100644 index 0000000000000000000000000000000000000000..1642d0c295c14a7d002fc9b2023802a764109b27 --- /dev/null +++ b/MindEarth/applications/nowcasting/PreDiff/src/diffusion/time_embed.py @@ -0,0 +1,270 @@ +# Copyright 2023 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"TimeEmbedLayer and TimeEmbedResBlock" +from mindspore import nn, ops + +from src.utils import conv_nd, apply_initialization, avg_pool_nd + + +class Upsample(nn.Cell): + """ + An upsampling layer with an optional convolution. + :param channels: channels in the inputs and outputs. + :param use_conv: a bool determining if a convolution is applied. + :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then + upsampling occurs in the inner-two dimensions. + """ + + def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1): + super().__init__() + self.channels = channels + self.out_channels = out_channels or channels + self.use_conv = use_conv + self.dims = dims + if use_conv: + self.conv = conv_nd( + dims, self.channels, self.out_channels, 3, padding=padding + ) + + def construct(self, x): + '''upsample forward''' + assert x.shape[1] == self.channels + if self.dims == 3: + x = ops.interpolate( + x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode="nearest" + ) + else: + x = ops.interpolate(x, scale_factor=2, mode="nearest") + if self.use_conv: + x = self.conv(x) + return x + + +class Downsample(nn.Cell): + """ + A downsampling layer with an optional convolution. + :param channels: channels in the inputs and outputs. 
+ :param use_conv: a bool determining if a convolution is applied. + :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then + downsampling occurs in the inner-two dimensions. + """ + + def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1): + super().__init__() + self.channels = channels + self.out_channels = out_channels or channels + self.use_conv = use_conv + self.dims = dims + stride = 2 if dims != 3 else (1, 2, 2) + if use_conv: + self.op = conv_nd( + dims, + self.channels, + self.out_channels, + 3, + stride=stride, + padding=padding, + ) + else: + assert self.channels == self.out_channels + self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride) + + def construct(self, x): + assert x.shape[1] == self.channels + return self.op(x) + + +class TimeEmbedLayer(nn.Cell): + """ + A neural network layer that embeds time information into a higher-dimensional space. + + The layer consists of two linear layers separated by a SiLU activation function. + It takes an input tensor with a specified number of base channels and transforms it + into a tensor with a specified number of time embedding channels. + Parameters: + - base_channels (int): Number of channels in the input tensor. + - time_embed_channels (int): Number of channels in the output embedded tensor. + - linear_init_mode (str, optional): Initialization mode for the linear layers. Defaults to "0". + """ + + def __init__(self, base_channels, time_embed_channels, linear_init_mode="0"): + super().__init__() + self.layer = nn.SequentialCell( + nn.Dense(base_channels, time_embed_channels), + nn.SiLU(), + nn.Dense(time_embed_channels, time_embed_channels), + ) + self.linear_init_mode = linear_init_mode + + def construct(self, x): + """Forward pass through the TimeEmbedLayer.""" + return self.layer(x) + + def reset_parameters(self): + """Reset the parameters of the linear layers in the TimeEmbedLayer.""" + apply_initialization(self.layer[0], linear_mode=self.linear_init_mode) + apply_initialization(self.layer[2], linear_mode=self.linear_init_mode) + + +class TimeEmbedResBlock(nn.Cell): + r""" + Modifications: + 1. Change GroupNorm32 to use arbitrary `num_groups`. + 2. Add method `self.reset_parameters()`. + 3. Use gradient ckpt from mindspore instead of the stable diffusion implementation + 4. If no input time embed, it degrades to res block. 
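+
+    When `up` or `down` is set, both the hidden branch and the skip branch are resampled
+    by the (non-convolutional) Upsample/Downsample layers before the first convolution.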
+ """ + + def __init__( + self, + channels, + dropout, + emb_channels=None, + out_channels=None, + use_conv=False, + use_embed=True, + use_scale_shift_norm=False, + dims=2, + up=False, + down=False, + norm_groups=32, + ): + r""" + Parameters + ---------- + channels + dropout + emb_channels + out_channels + use_conv + use_embed: bool + include `emb` as input in `self.forward()` + use_scale_shift_norm: bool + take effect only when `use_embed == True` + dims + up + down + norm_groups + """ + super().__init__() + self.channels = channels + self.dropout = dropout + self.use_embed = use_embed + if use_embed: + assert isinstance(emb_channels, int) + self.emb_channels = emb_channels + self.out_channels = out_channels or channels + self.use_conv = use_conv + self.use_scale_shift_norm = use_scale_shift_norm + + self.in_layers = nn.SequentialCell( + nn.GroupNorm( + num_groups=norm_groups if channels % norm_groups == 0 else channels, + num_channels=channels, + ), + nn.SiLU(), + conv_nd(dims, channels, self.out_channels, 3, padding=1), + ) + self.updown = up or down + + if up: + self.h_upd = Upsample(channels, False, dims) + self.x_upd = Upsample(channels, False, dims) + elif down: + self.h_upd = Downsample(channels, False, dims) + self.x_upd = Downsample(channels, False, dims) + else: + self.h_upd = self.x_upd = nn.Identity() + + if use_embed: + self.emb_layers = nn.SequentialCell( + nn.SiLU(), + nn.Dense( + in_channels=emb_channels, + out_channels=( + 2 * self.out_channels + if use_scale_shift_norm + else self.out_channels + ), + ), + ) + self.out_layers = nn.SequentialCell( + nn.GroupNorm( + num_groups=( + norm_groups + if self.out_channels % norm_groups == 0 + else self.out_channels + ), + num_channels=self.out_channels, + ), + nn.SiLU(), + nn.Dropout(p=dropout), + # nn.Dropout(p=0), + conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1), + ) + if self.out_channels == channels: + self.skip_connection = nn.Identity() + elif use_conv: + self.skip_connection = conv_nd( + dims, channels, self.out_channels, 3, padding=1 + ) + else: + self.skip_connection = conv_nd(dims, channels, self.out_channels, 1) + + self.reset_parameters() + + def construct(self, x, emb=None): + """ + Apply the block to a Tensor, conditioned on a timestep embedding. + + Parameters + ---------- + x: an [N x C x ...] Tensor of features. + emb: an [N x emb_channels] Tensor of timestep embeddings. + + Returns + ------- + out: an [N x C x ...] Tensor of outputs. 
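+
+        Notes
+        -----
+        With `use_scale_shift_norm=True` the timestep embedding is split into
+        (scale, shift) and applied as `h = norm(h) * (1 + scale) + shift`;
+        otherwise the embedding is simply added to `h` before the output layers.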
+ """ + if self.updown: + in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1] + h = in_rest(x) + h = self.h_upd(h) + x = self.x_upd(x) + h = in_conv(h) + else: + h = self.in_layers(x) + if self.use_embed: + emb_out = self.emb_layers(emb).astype(h.dtype) + while len(emb_out.shape) < len(h.shape): + emb_out = emb_out[..., None] + if self.use_scale_shift_norm: + out_norm, out_rest = self.out_layers[0], self.out_layers[1:] + scale, shift = ops.chunk(emb_out, 2, axis=1) + h = out_norm(h) * (1 + scale) + shift + h = out_rest(h) + else: + h = h + emb_out + h = self.out_layers(h) + else: + h = self.out_layers(h) + n = self.skip_connection(x) + h + return n + + def reset_parameters(self): + for _, cell in self.cells_and_names(): + apply_initialization(cell) + for p in self.out_layers[-1].get_parameters(): + p.set_data(ops.zeros(p.shape, dtype=p.dtype)) diff --git a/MindEarth/applications/nowcasting/PreDiff/src/knowledge_alignment/__init__.py b/MindEarth/applications/nowcasting/PreDiff/src/knowledge_alignment/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..eeb2f631dea92f088fd06c72f6c02492175df1c6 --- /dev/null +++ b/MindEarth/applications/nowcasting/PreDiff/src/knowledge_alignment/__init__.py @@ -0,0 +1,15 @@ +# Copyright 2023 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this filepio[] except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""init""" diff --git a/MindEarth/applications/nowcasting/PreDiff/src/knowledge_alignment/alignment.py b/MindEarth/applications/nowcasting/PreDiff/src/knowledge_alignment/alignment.py new file mode 100644 index 0000000000000000000000000000000000000000..f1a6785b3b2b6a7ecdeb3a53d94cfdaff3ac47a9 --- /dev/null +++ b/MindEarth/applications/nowcasting/PreDiff/src/knowledge_alignment/alignment.py @@ -0,0 +1,566 @@ +# Copyright 2023 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"NoisyCuboidTransformerEncoder" +import math +import numpy as np + +import mindspore as ms +from mindspore import nn, ops, mint +from mindspore.common.initializer import TruncatedNormal + +from src.utils import ( + conv_nd, + zero_module, + timestep_embedding, + apply_initialization, + round_to, + self_axial +) +from src.diffusion import ( + PatchMerging3D, + PosEmbed, + StackCuboidSelfAttentionBlock, + TimeEmbedLayer, + TimeEmbedResBlock, +) + + +class QKVAttention(nn.Cell): + """ + A module which performs QKV attention and splits in a different order. + """ + + def __init__(self, n_heads): + super().__init__() + self.n_heads = n_heads + + def construct(self, qkv): + """ + Apply QKV attention. + :param qkv: an [N x (3 * H * C) x T] tensor of Qs, Ks, and Vs. + :return: an [N x (H * C) x T] tensor after attention. + """ + bs, width, length = qkv.shape + assert width % (3 * self.n_heads) == 0 + ch = width // (3 * self.n_heads) + q, k, v = qkv.chunk(3, dim=1) + scale = 1 / math.sqrt(math.sqrt(ch)) + q_transposed = ops.transpose( + (q * scale).view(bs * self.n_heads, ch, length), (0, 2, 1) + ) + k_reshaped = (k * scale).view(bs * self.n_heads, ch, length) + weight = ops.BatchMatMul()(q_transposed, k_reshaped) + weight = nn.Softmax(axis=-1)(weight.float()).type(weight.dtype) + weight_transposed = ops.transpose(weight, (0, 2, 1)) + v_reshaped = v.reshape(bs * self.n_heads, ch, length) + a = ops.BatchMatMul()(v_reshaped, weight_transposed) + return a.reshape(bs, -1, length) + + +class AttentionPool3d(nn.Cell): + """ + Adapted from CLIP: https://github.com/openai/CLIP/blob/main/clip/model.py + """ + + def __init__( + self, + data_dim: int, + embed_dim: int, + num_heads: int, + output_dim: int = None, + init_mode: str = "0", + ): + r""" + Parameters + ---------- + data_dim: int + e.g. T*H*W if data is 3D + embed_dim: int + input data channels + num_heads: int + output_dim: int + """ + super().__init__() + self.positional_embedding = ms.Parameter( + ops.randn(embed_dim, data_dim + 1) / embed_dim**0.5 + ) + self.qkv_proj = conv_nd(1, embed_dim, 3 * embed_dim, 1) + self.c_proj = conv_nd(1, embed_dim, output_dim or embed_dim, 1) + self.num_heads = num_heads + self.attention = QKVAttention(self.num_heads) + self.init_mode = init_mode + + def construct(self, x): + r""" + + Parameters + ---------- + x: ms.Tensor + layout = "NCTHW" + + Returns + ------- + ret: ms.Tensor + layout = "NC" + """ + b, c, _ = x.shape + x = x.reshape(b, c, -1) + x = mint.cat([x.mean(axis=-1, keep_dims=True), x], dim=-1) + x = x + self.positional_embedding[None, :, :].to(x.dtype) + x = self.qkv_proj(x) + x = self.attention(x) + x = self.c_proj(x) + return x[:, :, 0] + + def reset_parameters(self): + '''set parameters''' + apply_initialization(self.qkv_proj, conv_mode="0") + apply_initialization(self.c_proj, conv_mode=self.init_mode) + + +class NoisyCuboidTransformerEncoder(nn.Cell): + r""" + Half U-Net style CuboidTransformerEncoder that parametrizes `U(z_t, t, ...)`. + It takes `x_t`, `t` as input. + The conditioning can be concatenated to the input like the U-Net in FVD paper. + + For each block, we apply the StackCuboidSelfAttention. The final block state is read out by a pooling layer. + + x --> attn --> downscale --> ... --> poll --> out + + Besides, we insert the embeddings of the timesteps `t` before each cuboid attention blocks. 
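+
+    In this project the pooled readout is consumed by `AvgIntensityAlignment` as the
+    knowledge-alignment predictor: given the noisy latent `z_t` and timestep `t`, it
+    estimates the statistic (here the average future intensity) that guided sampling
+    compares against the ground-truth value.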
+ """ + + def __init__( + self, + input_shape=None, + out_channels=1, + base_units=128, + block_units=None, + scale_alpha=1.0, + depth=None, + downsample=2, + downsample_type="patch_merge", + use_attn_pattern=None, + block_cuboid_size=None, + block_cuboid_strategy=None, + block_cuboid_shift_size=None, + num_heads=4, + attn_drop=0.0, + proj_drop=0.0, + ffn_drop=0.0, + ffn_activation="gelu", + gated_ffn=False, + norm_layer="layer_norm", + use_inter_ffn=True, + hierarchical_pos_embed=False, + padding_type="zeros", + use_relative_pos=True, + self_attn_use_final_proj=True, + # global vectors + num_global_vectors=0, + use_global_vector_ffn=True, + use_global_self_attn=False, + separate_global_qkv=False, + global_dim_ratio=1, + # initialization + attn_linear_init_mode="0", + ffn_linear_init_mode="0", + ffn2_linear_init_mode="2", + attn_proj_linear_init_mode="2", + conv_init_mode="0", + down_linear_init_mode="0", + global_proj_linear_init_mode="2", + norm_init_mode="0", + # timestep embedding for diffusion + time_embed_channels_mult=4, + time_embed_use_scale_shift_norm=False, + time_embed_dropout=0.0, + # readout + pool: str = "attention", + readout_seq: bool = True, + t_out: int = None, + ): + super().__init__() + # initialization mode + self.attn_linear_init_mode = attn_linear_init_mode + self.ffn_linear_init_mode = ffn_linear_init_mode + self.ffn2_linear_init_mode = ffn2_linear_init_mode + self.attn_proj_linear_init_mode = attn_proj_linear_init_mode + self.conv_init_mode = conv_init_mode + self.down_linear_init_mode = down_linear_init_mode + self.global_proj_linear_init_mode = global_proj_linear_init_mode + self.norm_init_mode = norm_init_mode + + self.input_shape = input_shape + self.out_channels = out_channels + self.num_blocks = len(depth) + self.depth = depth + self.base_units = base_units + self.scale_alpha = scale_alpha + self.downsample = downsample + self.downsample_type = downsample_type + if not isinstance(downsample, (tuple, list)): + downsample = (1, downsample, downsample) + if block_units is None: + block_units = [ + round_to(base_units * int((max(downsample) ** scale_alpha) ** i), 4) + for i in range(self.num_blocks) + ] + else: + assert len(block_units) == self.num_blocks and block_units[0] == base_units + self.block_units = block_units + self.hierarchical_pos_embed = hierarchical_pos_embed + self.num_global_vectors = num_global_vectors + use_global_vector = num_global_vectors > 0 + self.use_global_vector = use_global_vector + self.global_dim_ratio = global_dim_ratio + self.use_global_vector_ffn = use_global_vector_ffn + + self.time_embed_channels_mult = time_embed_channels_mult + self.time_embed_channels = self.block_units[0] * time_embed_channels_mult + self.time_embed_use_scale_shift_norm = time_embed_use_scale_shift_norm + self.time_embed_dropout = time_embed_dropout + self.pool = pool + self.readout_seq = readout_seq + self.t_out = t_out + + if self.use_global_vector: + self.init_global_vectors = ms.Parameter( + mint.zeros((self.num_global_vectors, global_dim_ratio * base_units)) + ) + + _, h_in, w_in, _ = input_shape + self.first_proj = TimeEmbedResBlock( + channels=input_shape[-1], + emb_channels=None, + dropout=proj_drop, + out_channels=self.base_units, + use_conv=False, + use_embed=False, + use_scale_shift_norm=False, + dims=3, + up=False, + down=False, + ) + self.pos_embed = PosEmbed( + embed_dim=base_units, + max_t=input_shape[0], + max_h=h_in, + max_w=w_in, + ) + + # diffusion time embed + self.time_embed = TimeEmbedLayer( + base_channels=self.block_units[0], + 
time_embed_channels=self.time_embed_channels, + ) + if self.num_blocks > 1: + if downsample_type == "patch_merge": + self.downsample_layers = nn.CellList( + [ + PatchMerging3D( + dim=self.block_units[i], + downsample=downsample, + padding_type=padding_type, + out_dim=self.block_units[i + 1], + linear_init_mode=down_linear_init_mode, + norm_init_mode=norm_init_mode, + ) + for i in range(self.num_blocks - 1) + ] + ) + else: + raise NotImplementedError + if self.use_global_vector: + self.down_layer_global_proj = nn.CellList( + [ + mint.nn.Linear( + in_features=global_dim_ratio * self.block_units[i], + out_features=global_dim_ratio * self.block_units[i + 1], + ) + for i in range(self.num_blocks - 1) + ] + ) + if self.hierarchical_pos_embed: + self.down_hierarchical_pos_embed_l = nn.CellList( + [ + PosEmbed( + embed_dim=self.block_units[i], + max_t=self.mem_shapes[i][0], + max_h=self.mem_shapes[i][1], + max_w=self.mem_shapes[i][2], + ) + for i in range(self.num_blocks - 1) + ] + ) + + if use_attn_pattern: + block_attn_patterns = self.depth + block_cuboid_size = [] + block_cuboid_strategy = [] + block_cuboid_shift_size = [] + for idx, _ in enumerate(block_attn_patterns): + cuboid_size, strategy, shift_size = self_axial(self.mem_shapes[idx]) + block_cuboid_size.append(cuboid_size) + block_cuboid_strategy.append(strategy) + block_cuboid_shift_size.append(shift_size) + else: + if not isinstance(block_cuboid_size[0][0], (list, tuple)): + block_cuboid_size = [block_cuboid_size for _ in range(self.num_blocks)] + else: + assert ( + len(block_cuboid_size) == self.num_blocks + ), f"Incorrect input format! Received block_cuboid_size={block_cuboid_size}" + + if not isinstance(block_cuboid_strategy[0][0], (list, tuple)): + block_cuboid_strategy = [ + block_cuboid_strategy for _ in range(self.num_blocks) + ] + else: + assert ( + len(block_cuboid_strategy) == self.num_blocks + ), f"Incorrect input format! Received block_strategy={block_cuboid_strategy}" + + if not isinstance(block_cuboid_shift_size[0][0], (list, tuple)): + block_cuboid_shift_size = [ + block_cuboid_shift_size for _ in range(self.num_blocks) + ] + else: + assert ( + len(block_cuboid_shift_size) == self.num_blocks + ), f"Incorrect input format! 
Received block_shift_size={block_cuboid_shift_size}" + self.block_cuboid_size = block_cuboid_size + self.block_cuboid_strategy = block_cuboid_strategy + self.block_cuboid_shift_size = block_cuboid_shift_size + + # cuboid self attention blocks + down_self_blocks = [] + # ResBlocks that incorporate `time_embed` + down_time_embed_blocks = [] + for i in range(self.num_blocks): + down_time_embed_blocks.append( + TimeEmbedResBlock( + channels=self.mem_shapes[i][-1], + emb_channels=self.time_embed_channels, + dropout=self.time_embed_dropout, + out_channels=self.mem_shapes[i][-1], + use_conv=False, + use_embed=True, + use_scale_shift_norm=self.time_embed_use_scale_shift_norm, + dims=3, + up=False, + down=False, + ) + ) + + ele_depth = depth[i] + + stack_cuboid_blocks = [ + StackCuboidSelfAttentionBlock( + dim=self.mem_shapes[i][-1], + num_heads=num_heads, + block_cuboid_size=block_cuboid_size[i], + block_strategy=block_cuboid_strategy[i], + block_shift_size=block_cuboid_shift_size[i], + attn_drop=attn_drop, + proj_drop=proj_drop, + ffn_drop=ffn_drop, + activation=ffn_activation, + gated_ffn=gated_ffn, + norm_layer=norm_layer, + use_inter_ffn=use_inter_ffn, + padding_type=padding_type, + use_global_vector=use_global_vector, + use_global_vector_ffn=use_global_vector_ffn, + use_global_self_attn=use_global_self_attn, + separate_global_qkv=separate_global_qkv, + global_dim_ratio=global_dim_ratio, + use_relative_pos=use_relative_pos, + use_final_proj=self_attn_use_final_proj, + # initialization + attn_linear_init_mode=attn_linear_init_mode, + ffn_linear_init_mode=ffn_linear_init_mode, + ffn2_linear_init_mode=ffn2_linear_init_mode, + attn_proj_linear_init_mode=attn_proj_linear_init_mode, + norm_init_mode=norm_init_mode, + ) + for _ in range(ele_depth) + ] + down_self_blocks.append(nn.CellList(stack_cuboid_blocks)) + + self.down_self_blocks = nn.CellList(down_self_blocks) + self.down_time_embed_blocks = nn.CellList(down_time_embed_blocks) + + out_shape = self.mem_shapes[-1] + cuboid_out_channels = out_shape[-1] + if pool == "adaptive": + self.out = nn.SequentialCell( + nn.GroupNorm(min(cuboid_out_channels, 32), cuboid_out_channels), + nn.SiLU(), + nn.AdaptiveAvgPool2d((1, 1)), + zero_module(conv_nd(2, cuboid_out_channels, out_channels, 1)), + nn.Flatten(), + ) + elif pool == "attention": + if readout_seq: + data_dim = np.prod(out_shape[1:-1]).item() + num_global_vectors + else: + data_dim = np.prod(out_shape[:-1]).item() + num_global_vectors + self.out = nn.SequentialCell( + nn.GroupNorm(min(cuboid_out_channels, 32), cuboid_out_channels), + nn.SiLU(), + AttentionPool3d( + data_dim, + cuboid_out_channels, + num_heads, + out_channels, + init_mode="0", + ), + ) + elif pool == "spatial": + self.out = nn.SequentialCell( + mint.nn.Linear(self._feature_size, 2048), + mint.nn.ReLU(), + mint.nn.Linear(2048, out_channels), + ) + elif pool == "spatial_v2": + self.out = nn.SequentialCell( + mint.nn.Linear(self._feature_size, 2048), + mint.nn.GroupNorm(2048, 2048), + nn.SiLU(), + mint.nn.Linear(2048, out_channels), + ) + else: + raise NotImplementedError(f"Unexpected {pool} pooling") + + self.reset_parameters() + + def reset_parameters(self): + """set parameters""" + if self.num_global_vectors > 0: + TruncatedNormal(self.init_global_vectors, sigma=0.02) + self.first_proj.reset_parameters() + self.pos_embed.reset_parameters() + # inner U-Net + for block in self.down_self_blocks: + for m in block: + m.reset_parameters() + for m in self.down_time_embed_blocks: + m.reset_parameters() + if self.num_blocks > 1: + for m in 
self.downsample_layers: + m.reset_parameters() + if self.use_global_vector: + apply_initialization( + self.down_layer_global_proj, + linear_mode=self.global_proj_linear_init_mode, + ) + if self.hierarchical_pos_embed: + for m in self.down_hierarchical_pos_embed_l: + m.reset_parameters() + if self.pool == "attention": + apply_initialization(self.out[0], norm_mode=self.norm_init_mode) + self.out[2].reset_parameters() + else: + raise NotImplementedError + + def transpose_and_first_proj(self, x, batch_size): + """transpose and first_proj""" + x = x.transpose(0, 4, 1, 2, 3) + x = self.first_proj(x) + x = x.transpose(0, 2, 3, 4, 1) + if self.use_global_vector: + global_vectors = self.init_global_vectors.broadcast_to( + batch_size, + self.num_global_vectors, + self.global_dim_ratio * self.base_units, + ) + return x, global_vectors + return x, None + + @property + def mem_shapes(self): + """Get the shape of the output memory based on the input shape. This can be used for constructing the decoder. + + Returns + ------- + mem_shapes + A list of shapes of the output memory + """ + inner_data_shape = tuple(self.input_shape)[:3] + (self.base_units,) + if self.num_blocks == 1: + return [inner_data_shape] + mem_shapes = [inner_data_shape] + curr_shape = inner_data_shape + for down_layer in self.downsample_layers: + curr_shape = down_layer.get_out_shape(curr_shape) + mem_shapes.append(curr_shape) + return mem_shapes + + def construct(self, x, t): + """ + Forward pass through the NoisyCuboidTransformerEncoder. + + Parameters: + - x (Tensor): Input tensor of shape (batch_size, seq_in, H, W, C). + - t (Tensor): Timestep tensor. + + Returns: + - Tensor: Output tensor after processing through the encoder. + """ + batch_size, seq_in, _, _, _ = x.shape + x, global_vectors = self.transpose_and_first_proj(x, batch_size) + x = self.pos_embed(x) + t_emb = self.time_embed(timestep_embedding(t, self.block_units[0])) + for i in range(self.num_blocks): + if i > 0: + x = self.downsample_layers[i - 1](x) + if self.hierarchical_pos_embed: + x = self.down_hierarchical_pos_embed_l[i - 1](x) + if self.use_global_vector: + global_vectors = self.down_layer_global_proj[i - 1](global_vectors) + for idx in range(self.depth[i]): + x = x.transpose(0, 4, 1, 2, 3) + x = self.down_time_embed_blocks[i](x, t_emb) + x = x.transpose(0, 2, 3, 4, 1) + if self.use_global_vector: + x, global_vectors = self.down_self_blocks[i][idx](x, global_vectors) + else: + x = self.down_self_blocks[i][idx](x) + + if self.readout_seq: + if self.t_out is not None: + seq_in = self.t_out + start_idx = x.shape[1] - self.t_out + x = x[:, start_idx:, ...] 
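+            # Per-frame readout: fold the (possibly cropped) time axis into the batch
+            # dimension so that pooling produces one output vector per frame.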
+ out = x.transpose(0, 1, 4, 2, 3) + b_out, t_out, c_out, h_out, w_out = out.shape + out = out.reshape(b_out * t_out, c_out, h_out * w_out) + if self.num_global_vectors > 0: + out_global = global_vectors.tile((seq_in, 1, 1)) + out_global = out_global.transpose(0, 2, 1) + out = mint.cat([out, out_global], dim=2) + out = self.out(out) + out = out.reshape(batch_size, seq_in, -1) + else: + out = x.transpose(0, 4, 1, 2, 3) + b_out, c_out, t_out, h_out, w_out = out.shape + out = out.reshape(b_out, c_out, t_out * h_out * w_out) + if self.num_global_vectors > 0: + out_global = global_vectors.transpose(0, 2, 1) + out = mint.cat([out, out_global], dim=2) + out = self.out(out) + return out diff --git a/MindEarth/applications/nowcasting/PreDiff/src/knowledge_alignment/alignment_net.py b/MindEarth/applications/nowcasting/PreDiff/src/knowledge_alignment/alignment_net.py new file mode 100644 index 0000000000000000000000000000000000000000..3cfeee2e2cd8bf7c8c10c7a20876c8956c9ac218 --- /dev/null +++ b/MindEarth/applications/nowcasting/PreDiff/src/knowledge_alignment/alignment_net.py @@ -0,0 +1,119 @@ +# Copyright 2023 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"alignment main model" +from typing import Dict, Any + +import mindspore +from mindspore import mint, ops + +from .alignment import NoisyCuboidTransformerEncoder + + +class AvgIntensityAlignment: + """ + A class for intensity alignment using a NoisyCuboidTransformerEncoder model to guide latent space adjustments. + """ + + def __init__( + self, + guide_scale: float = 1.0, + model_args: Dict[str, Any] = None, + model_ckpt_path: str = None, + ): + r""" + + Parameters + ---------- + alignment_type: str + guide_scale: float + model_type: str + model_args: Dict[str, Any] + model_ckpt_path: str + if not None, load the model from the checkpoint + """ + super().__init__() + self.guide_scale = guide_scale + if model_args is None: + model_args = {} + self.model = NoisyCuboidTransformerEncoder(**model_args) + self.load_ckpt(model_ckpt_path) + + def load_ckpt(self, model_ckpt_path): + if model_ckpt_path is not None: + param_dict = mindspore.load_checkpoint(model_ckpt_path) + param_not_load, _ = mindspore.load_param_into_net(self.model, param_dict) + print("NoisyCuboidTransformerEncoder param_not_load:", param_not_load) + + def get_sample_align_fn(self, sample_align_model): + """get_sample_align_fn""" + + def sample_align_fn(x, *args, **kwargs): + def forward_fn(x_in): + x_stop = ops.stop_gradient(x_in) + return sample_align_model(x_stop, *args, **kwargs) + + grad_fn = mindspore.grad(forward_fn, grad_position=0) + gradient = grad_fn(x) + return gradient + + return sample_align_fn + + def alignment_fn(self, zt, t, **kwargs): + r""" + transform the learned model to the final guidance \mathcal{F}. 
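+        Concretely, it returns `|| mean_t(model(z_t, t)) - avg_x_gt ||_2`; the gradient of
+        this scalar with respect to `z_t`, scaled by `guide_scale` in `get_mean_shift`,
+        is the shift applied to the latent during guided sampling.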
+ + Parameters + ---------- + zt: ms.Tensor + noisy latent z + t: ms.Tensor + timestamp + y: ms.Tensor + context sequence in pixel space + zc: ms.Tensor + encoded context sequence in latente space + kwargs: Dict[str, Any] + auxiliary knowledge for guided generation + `avg_x_gt`: float is required. + Returns + ------- + ret: ms.Tensor + """ + pred = self.model(zt, t) + target = kwargs.get("avg_x_gt") + pred = pred.mean(axis=1) + ret = mint.linalg.vector_norm(pred - target, ord=2) + return ret + + def get_mean_shift(self, zt, t, y=None, zc=None, **kwargs): + r""" + Parameters + ---------- + zt: ms.Tensor + noisy latent z + t: ms.Tensor + timestamp + y: ms.Tensor + context sequence in pixel space + zc: ms.Tensor + encoded context sequence in latente space + Returns + ------- + ret: ms.Tensor + \nabla_zt U + """ + grad_fn = self.get_sample_align_fn(self.alignment_fn) + grad = grad_fn(zt, t, y=y, zc=zc, **kwargs) + return self.guide_scale * grad diff --git a/MindEarth/applications/nowcasting/PreDiff/src/sevir_dataset.py b/MindEarth/applications/nowcasting/PreDiff/src/sevir_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..d0fca3d1c9ce714a11e1a203e22fa6e4bfec2048 --- /dev/null +++ b/MindEarth/applications/nowcasting/PreDiff/src/sevir_dataset.py @@ -0,0 +1,1025 @@ +# Copyright 2023 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"generate dataset" +import os +import datetime +from typing import Union, Sequence, Tuple +import h5py +import pandas as pd +from einops import rearrange +import numpy as np + +import mindspore as ms +import mindspore.dataset as ds +import mindspore.dataset.vision.transforms as vision +from mindspore import nn, ops, Tensor +from mindspore.dataset.vision import RandomRotation, Rotate +from mindspore.dataset.transforms import Compose + + +SEVIR_DATA_TYPES = ["vis", "ir069", "ir107", "vil", "lght"] +LIGHTING_FRAME_TIMES = np.arange(-120.0, 125.0, 5) * 60 +SEVIR_DATA_SHAPE = { + "lght": (48, 48), +} +PREPROCESS_SCALE_01 = { + "vis": 1, + "ir069": 1, + "ir107": 1, + "vil": 1 / 255, + "lght": 1, +} +PREPROCESS_OFFSET_01 = { + "vis": 0, + "ir069": 0, + "ir107": 0, + "vil": 0, + "lght": 0, +} + + +def path_splitall(path): + """ + Split a file path into all its components. + + Recursively splits the path into directory components and the final file name, + handling both absolute and relative paths across different OS conventions. + + Args: + path (str): Input file path to split + + Returns: + List[str]: List of path components from root to leaf + """ + allparts = [] + while 1: + parts = os.path.split(path) + if parts[0] == path: + allparts.insert(0, parts[0]) + break + elif parts[1] == path: + allparts.insert(0, parts[1]) + break + else: + path = parts[0] + allparts.insert(0, parts[1]) + return allparts + + +def change_layout(data, in_layout="NHWT", out_layout="NHWT"): + """ + Convert data layout between different dimension orderings. 
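+    For instance, an array of shape (1, 128, 128, 25) with in_layout="NHWT" and
+    out_layout="NTHWC" comes back with shape (1, 25, 128, 128, 1).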
+ + Handles layout transformations using einops.rearrange, with special handling + for 'C' (channel) dimensions which are treated as singleton dimensions. + + Args: + data (Tensor/ndarray): Input data to transform + in_layout (str): Current dimension order (e.g., "NHWT") + out_layout (str): Target dimension order (e.g., "THWC") + + Returns: + ndarray: Data in new layout with applied transformations + """ + if isinstance(data, ms.Tensor): + data = data.asnumpy() + in_layout = " ".join(in_layout.replace("C", "1")) + out_layout = " ".join(out_layout.replace("C", "1")) + data = rearrange(data, f"{in_layout} -> {out_layout}") + return data + + +class DatasetSEVIR: + """ + SEVIR Dataset class for weather event sequence data. + + Provides data loading and augmentation capabilities for SEVIR (Severe Weather Events Dataset) + with support for different temporal layouts and data preprocessing. + + Attributes: + layout (str): Output data layout configuration + sevir_dataloader (SEVIRDataLoader): Core data loading component + aug_pipeline (AugmentationPipeline): Data augmentation operations + """ + def __init__( + self, + seq_in: int = 25, + raw_seq_in: int = 49, + sample_mode: str = "sequent", + stride: int = 12, + layout: str = "THWC", + ori_layout: str = "NHWT", + split_mode: str = "uneven", + sevir_catalog: Union[str, pd.DataFrame] = None, + sevir_data_dir: str = None, + start_date: datetime.datetime = None, + end_date: datetime.datetime = None, + datetime_filter=None, + catalog_filter="default", + shuffle: bool = False, + shuffle_seed: int = 1, + output_type=np.float32, + preprocess: bool = True, + rescale_method: str = "01", + verbose: bool = False, + aug_mode: str = "0", + ): + super().__init__() + self.layout = layout.replace("C", "1") + self.sevir_dataloader = SEVIRDataLoader( + data_types=[ + "vil", + ], + seq_in=seq_in, + raw_seq_in=raw_seq_in, + sample_mode=sample_mode, + stride=stride, + batch_size=1, + layout=ori_layout, + num_shard=1, + rank=0, + split_mode=split_mode, + sevir_catalog=sevir_catalog, + sevir_data_dir=sevir_data_dir, + start_date=start_date, + end_date=end_date, + datetime_filter=datetime_filter, + catalog_filter=catalog_filter, + shuffle=shuffle, + shuffle_seed=shuffle_seed, + output_type=output_type, + preprocess=preprocess, + rescale_method=rescale_method, + verbose=verbose, + ) + self.aug_mode = aug_mode + self.aug_pipeline = AugmentationPipeline( + self.aug_mode, + self.layout, + ) + + def __getitem__(self, index): + """ + Get processed data sample by index. + + Performs data extraction, augmentation, and layout conversion. + + Args: + index (int): Sample index + + Returns: + ndarray: Processed data in specified layout + """ + data_dict = self.sevir_dataloader.extract_data(index=index) + data = data_dict["vil"] + if self.aug_pipeline is not None: + data = self.aug_pipeline(data_dict) + return data + + def __len__(self): + """len""" + return self.sevir_dataloader.__len__() + + +class SEVIRDataModule(nn.Cell): + """ + DataModule for SEVIR dataset. + + Manages dataset splits (train/val/test), data loading, and augmentation + for training diffusion models on weather event sequences. 
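+
+    Minimal usage sketch (the dataset directory is a placeholder):
+
+        dm = SEVIRDataModule(sevir_dir="/path/to/sevir_lr", batch_size=1)
+        dm.setup("fit")
+        train_ds = dm.sevir_train  # batched GeneratorDataset yielding the "vil" column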
+ + Attributes: + sevir_dir (str): Root directory of SEVIR dataset + batch_size (int): Data loader batch size + num_workers (int): Number of data loader workers + aug_mode (str): Data augmentation configuration + layout (str): Data layout configuration + """ + + def __init__( + self, + seq_in: int = 25, + sample_mode: str = "sequent", + stride: int = 12, + layout: str = "NTHWC", + output_type=np.float32, + preprocess: bool = True, + rescale_method: str = "01", + verbose: bool = False, + aug_mode: str = "0", + dataset_name: str = "sevir", + sevir_dir: str = None, + start_date: Tuple[int] = None, + train_val_split_date: Tuple[int] = (2019, 3, 20), + train_test_split_date: Tuple[int] = (2019, 6, 1), + end_date: Tuple[int] = None, + val_ratio: float = 0.1, + batch_size: int = 1, + num_workers: int = 1, + raw_seq_len: int = 25, + seed: int = 0, + ): + super().__init__() + self.sevir_dir = sevir_dir + self.aug_mode = aug_mode + self.seq_in = seq_in + self.sample_mode = sample_mode + self.stride = stride + self.layout = layout.replace("N", "") + self.output_type = output_type + self.preprocess = preprocess + self.rescale_method = rescale_method + self.verbose = verbose + self.aug_mode = aug_mode + self.batch_size = batch_size + self.num_workers = num_workers + self.seed = seed + self.dataset_name = dataset_name + self.sevir_dir = sevir_dir + self.catalog_path = os.path.join(sevir_dir, "CATALOG.csv") + self.raw_data_dir = os.path.join(sevir_dir, "data") + self.raw_seq_in = raw_seq_len + self.start_date = ( + datetime.datetime(*start_date) if start_date is not None else None + ) + self.train_test_split_date = ( + datetime.datetime(*train_test_split_date) + if train_test_split_date is not None + else None + ) + self.train_val_split_date = ( + datetime.datetime(*train_val_split_date) + if train_val_split_date is not None + else None + ) + self.end_date = datetime.datetime(*end_date) if end_date is not None else None + self.val_ratio = val_ratio + + def setup(self, stage=None) -> None: + """ + Prepare dataset splits for different stages. + + Creates train/val/test splits based on date ranges and configuration. + + Args: + stage (str): Current stage ("fit", "test", etc.) 
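+
+        Note:
+            The splits are carved out by date: train runs from `start_date` to
+            `train_val_split_date`, val from `train_val_split_date` to
+            `train_test_split_date`, and test from `train_test_split_date` to `end_date`.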
+ """ + if stage in (None, "fit"): + print("train") + self.sevir_train_ori = DatasetSEVIR( + sevir_catalog=self.catalog_path, + sevir_data_dir=self.raw_data_dir, + raw_seq_in=self.raw_seq_in, + split_mode="uneven", + shuffle=False, + seq_in=self.seq_in, + stride=self.stride, + sample_mode=self.sample_mode, + layout=self.layout, + start_date=self.start_date, + end_date=self.train_val_split_date, + output_type=self.output_type, + preprocess=self.preprocess, + rescale_method=self.rescale_method, + verbose=self.verbose, + aug_mode=self.aug_mode, + ) + self.sevir_train = ds.GeneratorDataset( + source=self.sevir_train_ori, + column_names="vil", + shuffle=False, + num_parallel_workers=self.num_workers, + ) + self.sevir_train = self.sevir_train.batch(batch_size=self.batch_size) + + if stage in (None, "fit"): + print("val") + self.sevir_val = DatasetSEVIR( + sevir_catalog=self.catalog_path, + sevir_data_dir=self.raw_data_dir, + raw_seq_in=self.raw_seq_in, + split_mode="uneven", + shuffle=False, + seq_in=self.seq_in, + stride=self.stride, + sample_mode=self.sample_mode, + layout=self.layout, + start_date=self.train_val_split_date, + end_date=self.train_test_split_date, + output_type=self.output_type, + preprocess=self.preprocess, + rescale_method=self.rescale_method, + verbose=self.verbose, + aug_mode=self.aug_mode, + ) + self.sevir_val = ds.GeneratorDataset( + source=self.sevir_val, + column_names="vil", + shuffle=False, + num_parallel_workers=self.num_workers, + ) + self.sevir_val = self.sevir_val.batch(batch_size=self.batch_size) + + if stage in (None, "test"): + print("test") + self.sevir_test = DatasetSEVIR( + sevir_catalog=self.catalog_path, + sevir_data_dir=self.raw_data_dir, + raw_seq_in=self.raw_seq_in, + split_mode="uneven", + shuffle=False, + seq_in=self.seq_in, + stride=self.stride, + sample_mode=self.sample_mode, + layout=self.layout, + start_date=self.train_test_split_date, + end_date=self.end_date, + output_type=self.output_type, + preprocess=self.preprocess, + rescale_method=self.rescale_method, + verbose=self.verbose, + aug_mode=self.aug_mode, + ) + self.sevir_test = ds.GeneratorDataset( + source=self.sevir_test, + column_names="vil", + shuffle=False, + num_parallel_workers=self.num_workers, + ) + self.sevir_test = self.sevir_test.batch(batch_size=self.batch_size) + + @property + def num_train_samples(self): + """Get number of training samples""" + return len(self.sevir_train_ori) + + @property + def num_val_samples(self): + """Get number of validation samples""" + return len(self.sevir_val) + + @property + def num_test_samples(self): + """Get number of test samples""" + return len(self.sevir_test) + + +class SEVIRDataLoader: + r""" + DataLoader that loads SEVIR sequences, and spilts each event + into segments according to specified sequence length. 
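+
+    Each event provides `raw_seq_in` frames; a sliding window of length `seq_in` with
+    step `stride` yields `1 + (raw_seq_in - seq_in) // stride` sub-sequences per event
+    (see `num_seq_per_event`).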
+ """ + + def __init__( + self, + data_types: Sequence[str] = None, + seq_in: int = 49, + raw_seq_in: int = 49, + sample_mode: str = "sequent", + stride: int = 12, + batch_size: int = 1, + layout: str = "NHWT", + num_shard: int = 1, + rank: int = 0, + split_mode: str = "uneven", + sevir_catalog: Union[str, pd.DataFrame] = None, + sevir_data_dir: str = None, + start_date: datetime.datetime = None, + end_date: datetime.datetime = None, + datetime_filter=None, + catalog_filter="default", + shuffle: bool = False, + shuffle_seed: int = 1, + output_type=np.float32, + preprocess: bool = True, + rescale_method: str = "01", + verbose: bool = False, + ): + super().__init__() + + # configs which should not be modified + self.lght_frame_times = LIGHTING_FRAME_TIMES + self.data_shape = SEVIR_DATA_SHAPE + + self.raw_seq_in = raw_seq_in + assert ( + seq_in <= self.raw_seq_in + ), f"seq_in must not be larger than raw_seq_in = {raw_seq_in}, got {seq_in}." + self.seq_in = seq_in + assert sample_mode in [ + "random", + "sequent", + ], f"Invalid sample_mode = {sample_mode}, must be 'random' or 'sequent'." + self.sample_mode = sample_mode + self.stride = stride + self.batch_size = batch_size + valid_layout = ("NHWT", "NTHW", "NTCHW", "NTHWC", "TNHW", "TNCHW") + if layout not in valid_layout: + raise ValueError( + f"Invalid layout = {layout}! Must be one of {valid_layout}." + ) + self.layout = layout + self.num_shard = num_shard + self.rank = rank + valid_split_mode = ("ceil", "floor", "uneven") + if split_mode not in valid_split_mode: + raise ValueError( + f"Invalid split_mode: {split_mode}! Must be one of {valid_split_mode}." + ) + self.split_mode = split_mode + self._samples = None + self._hdf_files = {} + self.data_types = data_types + if isinstance(sevir_catalog, str): + self.catalog = pd.read_csv( + sevir_catalog, parse_dates=["time_utc"], low_memory=False + ) + else: + self.catalog = sevir_catalog + self.sevir_data_dir = sevir_data_dir + self.datetime_filter = datetime_filter + self.catalog_filter = catalog_filter + self.start_date = start_date + self.end_date = end_date + self.shuffle = shuffle + self.shuffle_seed = int(shuffle_seed) + self.output_type = output_type + self.preprocess = preprocess + self.rescale_method = rescale_method + self.verbose = verbose + + if self.start_date is not None: + self.catalog = self.catalog[self.catalog.time_utc > self.start_date] + if self.end_date is not None: + self.catalog = self.catalog[self.catalog.time_utc <= self.end_date] + if self.datetime_filter: + self.catalog = self.catalog[self.datetime_filter(self.catalog.time_utc)] + + if self.catalog_filter is not None: + if self.catalog_filter == "default": + self.catalog_filter = lambda c: c.pct_missing == 0 + self.catalog = self.catalog[self.catalog_filter(self.catalog)] + + self._compute_samples() + print(self._samples.head(n=10)) + print("len", len(self._samples)) + self._open_files(verbose=self.verbose) + self.reset() + + def _compute_samples(self): + """ + Computes the list of samples in catalog to be used. 
This sets self._samples + """ + imgt = self.data_types + imgts = set(imgt) + filtcat = self.catalog[ + np.logical_or.reduce([self.catalog.img_type == i for i in imgt]) + ] + filtcat = filtcat.groupby("id").filter( + lambda x: imgts.issubset(set(x["img_type"])) + ) + filtcat = filtcat.groupby("id").filter(lambda x: x.shape[0] == len(imgt)) + self._samples = filtcat.groupby("id").apply( + lambda df: self._df_to_series(df, imgt) + ) + if self.shuffle: + self.shuffle_samples() + + def shuffle_samples(self): + """Shuffle the dataset samples using a fixed random seed for reproducibility.""" + self._samples = self._samples.sample(frac=1, random_state=self.shuffle_seed) + + def _df_to_series(self, df, imgt): + """Convert catalog DataFrame entries to structured format for multi-image types.""" + d = {} + df = df.set_index("img_type") + for i in imgt: + s = df.loc[i] + idx = s.file_index if i != "lght" else s.id + d.update({f"{i}_filename": [s.file_name], f"{i}_index": [idx]}) + + return pd.DataFrame(d) + + def _open_files(self, verbose=True): + """ + Opens HDF files + """ + imgt = self.data_types + hdf_filenames = [] + for t in imgt: + hdf_filenames += list(np.unique(self._samples[f"{t}_filename"].values)) + + print("hdf_filenames", hdf_filenames) + self._hdf_files = {} + for f in hdf_filenames: + print("Opening HDF5 file for reading", f) + if verbose: + print("Opening HDF5 file for reading", f) + self._hdf_files[f] = h5py.File(self.sevir_data_dir + "/" + f, "r") + print("f:", f) + print("self._hdf_files[f]:", self._hdf_files[f]) + + def close(self): + """ + Closes all open file handles + """ + for f in self._hdf_files: + self._hdf_files[f].close() + print("close: ", f) + self._hdf_files = {} + + @property + def num_seq_per_event(self): + """num seq per event""" + return 1 + (self.raw_seq_in - self.seq_in) // self.stride + + @property + def total_num_seq(self): + """ + The total number of sequences within each shard. + Notice that it is not the product of `self.num_seq_per_event` and `self.total_num_event`. + """ + return int(self.num_seq_per_event * self.num_event) + + @property + def total_num_event(self): + """ + The total number of events in the whole dataset, before split into different shards. + """ + return int(self._samples.shape[0]) + + @property + def start_event_idx(self): + """ + The event idx used in certain rank should satisfy event_idx >= start_event_idx + """ + return self.total_num_event // self.num_shard * self.rank + + @property + def end_event_idx(self): + """ + The event idx used in certain rank should satisfy event_idx < end_event_idx + + """ + if self.split_mode == "ceil": + last_start_event_idx = ( + self.total_num_event // self.num_shard * (self.num_shard - 1) + ) + num_event = self.total_num_event - last_start_event_idx + return self.start_event_idx + num_event + if self.split_mode == "floor": + return self.total_num_event // self.num_shard * (self.rank + 1) + if self.rank == self.num_shard - 1: + return self.total_num_event + return self.total_num_event // self.num_shard * (self.rank + 1) + + @property + def num_event(self): + """ + The number of events split into each rank + """ + return self.end_event_idx - self.start_event_idx + + def _read_data(self, row, data): + """ + Iteratively read data into data dict. Finally data[imgt] gets shape (batch_size, height, width, raw_seq_in). + + Parameters + ---------- + row + A series with fields IMGTYPE_filename, IMGTYPE_index, IMGTYPE_time_index. 
+ data + Dict, data[imgt] is a data tensor with shape = (tmp_batch_size, height, width, raw_seq_in). + + Returns + ------- + data + Updated data. Updated shape = (tmp_batch_size + 1, height, width, raw_seq_in). + """ + imgtyps = np.unique([x.split("_")[0] for x in list(row.keys())]) + for t in imgtyps: + fname = row[f"{t}_filename"] + idx = row[f"{t}_index"] + t_slice = slice(0, None) + if t == "lght": + lght_data = self._hdf_files[fname][idx][:] + data_i = self._lght_to_grid(lght_data, t_slice) + else: + data_i = self._hdf_files[fname][t][idx : idx + 1, :, :, t_slice] + data[t] = ( + np.concatenate((data[t], data_i), axis=0) if (t in data) else data_i + ) + + return data + + def _lght_to_grid(self, data, t_slice=slice(0, None)): + """ + Converts Nx5 lightning data matrix into a 2D grid of pixel counts + """ + + out_size = ( + (*self.data_shape["lght"], len(self.lght_frame_times)) + if t_slice.stop is None + else (*self.data_shape["lght"], 1) + ) + if data.shape[0] == 0: + return np.zeros((1,) + out_size, dtype=np.float32) + + x, y = data[:, 3], data[:, 4] + m = np.logical_and.reduce([x >= 0, x < out_size[0], y >= 0, y < out_size[1]]) + data = data[m, :] + if data.shape[0] == 0: + return np.zeros((1,) + out_size, dtype=np.float32) + t = data[:, 0] + if t_slice.stop is not None: + if t_slice.stop > 0: + if t_slice.stop < len(self.lght_frame_times): + tm = np.logical_and( + t >= self.lght_frame_times[t_slice.stop - 1], + t < self.lght_frame_times[t_slice.stop], + ) + else: + tm = t >= self.lght_frame_times[-1] + else: + tm = np.logical_and( + t >= self.lght_frame_times[0], t < self.lght_frame_times[1] + ) + + data = data[tm, :] + z = np.zeros(data.shape[0], dtype=np.int64) + else: + z = np.digitize(t, self.lght_frame_times) - 1 + z[z == -1] = 0 + + x = data[:, 3].astype(np.int64) + y = data[:, 4].astype(np.int64) + + k = np.ravel_multi_index(np.array([y, x, z]), out_size) + n = np.bincount(k, minlength=np.prod(out_size)) + return np.reshape(n, out_size).astype(np.int16)[np.newaxis, :] + + @property + def sample_count(self): + """ + Record how many times self.__next__() is called. + """ + return self._sample_count + + @property + def _curr_event_idx(self): + return self.__curr_event_idx + + @property + def _curr_seq_idx(self): + """ + Used only when self.sample_mode == 'sequent' + """ + return self.__curr_seq_idx + + def _set__curr_event_idx(self, val): + self.__curr_event_idx = val + + def _set__curr_seq_idx(self, val): + """ + Used only when self.sample_mode == 'sequent' + """ + self.__curr_seq_idx = val + + def reset(self, shuffle: bool = None): + """reset""" + self._set__curr_event_idx(val=self.start_event_idx) + self._set__curr_seq_idx(0) + self._sample_count = 0 + if shuffle is None: + shuffle = self.shuffle + if shuffle: + self.shuffle_samples() + + def __len__(self): + """ + Used only when self.sample_mode == 'sequent' + """ + return self.total_num_seq // self.batch_size + + def _load_event_batch(self, event_idx, event_batch_size): + """ + Loads a selected batch of events (not batch of sequences) into memory. + + Parameters + ---------- + idx + event_batch_size + event_batch[i] = all_type_i_available_events[idx:idx + event_batch_size] + Returns + ------- + event_batch + list of event batches. + event_batch[i] is the event batch of the i-th data type. 
+ Each event_batch[i] is a np.ndarray with shape = (event_batch_size, height, width, raw_seq_in) + """ + event_idx_slice_end = event_idx + event_batch_size + pad_size = 0 + if event_idx_slice_end > self.end_event_idx: + pad_size = event_idx_slice_end - self.end_event_idx + event_idx_slice_end = self.end_event_idx + pd_batch = self._samples.iloc[event_idx:event_idx_slice_end] + data = {} + for _, row in pd_batch.iterrows(): + data = self._read_data(row, data) + if pad_size > 0: + event_batch = [] + for t in self.data_types: + pad_shape = [ + pad_size, + ] + list(data[t].shape[1:]) + data_pad = np.concatenate( + ( + data[t].astype(self.output_type), + np.zeros(pad_shape, dtype=self.output_type), + ), + axis=0, + ) + event_batch.append(data_pad) + else: + event_batch = [data[t].astype(self.output_type) for t in self.data_types] + return event_batch + + + def extract_data(self, index): + """ + Extracts a batch of data without any processing. + + Parameters + ---------- + index + The index of the batch to sample. + + Returns + ------- + event_batch + The extracted data from the event batch without any processing. + """ + event_idx = (index * self.batch_size) // self.num_seq_per_event + seq_idx = (index * self.batch_size) % self.num_seq_per_event + num_sampled = 0 + sampled_idx_list = [] + while num_sampled < self.batch_size: + sampled_idx_list.append({"event_idx": event_idx, "seq_idx": seq_idx}) + seq_idx += 1 + if seq_idx >= self.num_seq_per_event: + event_idx += 1 + seq_idx = 0 + num_sampled += 1 + + start_event_idx = sampled_idx_list[0]["event_idx"] + event_batch_size = sampled_idx_list[-1]["event_idx"] - start_event_idx + 1 + + event_batch = self._load_event_batch( + event_idx=start_event_idx, event_batch_size=event_batch_size + ) + ret_dict = {} + for sampled_idx in sampled_idx_list: + batch_slice = [ + sampled_idx["event_idx"] - start_event_idx, + ] + seq_slice = slice( + sampled_idx["seq_idx"] * self.stride, + sampled_idx["seq_idx"] * self.stride + self.seq_in, + ) + for imgt_idx, imgt in enumerate(self.data_types): + sampled_seq = event_batch[imgt_idx][batch_slice, :, :, seq_slice] + if imgt in ret_dict: + ret_dict[imgt] = np.concatenate( + (ret_dict[imgt], sampled_seq), axis=0 + ) + else: + ret_dict.update({imgt: sampled_seq}) + + return ret_dict + + +class AugmentationPipeline: + """Data augmentation pipeline for multi-frame image processing. + """ + def __init__( + self, + aug_mode="0", + layout=None, + ): + self.layout = layout + self.aug_mode = aug_mode + + if aug_mode == "0": + self.aug = lambda x: x + elif self.aug_mode == "1": + self.aug = Compose( + [ + vision.RandomHorizontalFlip(), + vision.RandomVerticalFlip(), + RandomRotation(degrees=180), + ] + ) + elif aug_mode == "2": + self.aug = Compose( + [ + vision.RandomHorizontalFlip(), + vision.RandomVerticalFlip(), + FixedAngleRotation(angles=[0, 90, 180, 270]), + ] + ) + else: + raise NotImplementedError + + def rearrange_tensor(self, tensor, from_layout, to_layout): + """Permute and reshape tensor dimensions according to layout specifications.""" + return tensor.permute(*tuple(range(len(from_layout)))).reshape(to_layout) + + def __call__(self, data_dict): + """Apply augmentation pipeline to input data dictionary. 
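# Illustrative sketch (toy numbers, not the SEVIR-LR defaults) of the index
# arithmetic used by `extract_data` above: a global batch index is decomposed
# into (event_idx, seq_idx) pairs, where each event yields
# num_seq_per_event = 1 + (raw_seq_in - seq_in) // stride overlapping windows.
raw_seq_in, seq_in, stride, batch_size = 49, 25, 12, 2
num_seq_per_event = 1 + (raw_seq_in - seq_in) // stride   # -> 3

def batch_to_samples(index):
    event_idx = (index * batch_size) // num_seq_per_event
    seq_idx = (index * batch_size) % num_seq_per_event
    samples = []
    for _ in range(batch_size):
        samples.append((event_idx, seq_idx))
        seq_idx += 1
        if seq_idx >= num_seq_per_event:
            event_idx, seq_idx = event_idx + 1, 0
    return samples

print(batch_to_samples(0))   # [(0, 0), (0, 1)]
print(batch_to_samples(1))   # [(0, 2), (1, 0)]  -- one batch can span two events
# each (event_idx, seq_idx) selects frames [seq_idx * stride, seq_idx * stride + seq_in)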
+ + Args: + data_dict (dict): Input data containing "vil" key with tensor data + + Returns: + ms.Tensor: Processed tensor with applied augmentations and layout conversion + """ + data = data_dict["vil"].squeeze(0) + if self.aug_mode != "0": + data = rearrange( + data, + "H W T -> T H W", + ) + data = self.aug(data) + data = rearrange(data, f"{' '.join('THW')} -> {' '.join(self.layout)}") + else: + data = rearrange( + data, + f"{' '.join('HWT')} -> {' '.join(self.layout)}", + ) + + return data + + +class FixedAngleRotation: + """Image augmentation for rotating images by fixed predefined angles. + + Args: + angles (List[int]): List of allowed rotation angles (degrees) + """ + def __init__(self, angles=None): + self.angles = angles + + def __call__(self, img): + """Apply random rotation from predefined angles. + + Args: + img (PIL.Image or mindspore.Tensor): Input image to transform + + Returns: + PIL.Image or mindspore.Tensor: Rotated image with same format as input + """ + angle = np.random.choice(self.angles) + return Rotate(angle)(img) + + +class SEVIRDataset: + """Base dataset class for processing SEVIR data with configurable preprocessing. + + Args: + data_types (Sequence[str], optional): + List of data types to process (e.g., ["vil", "lght"]). Defaults to SEVIR_DATA_TYPES. + layout (str, optional): + Tensor layout specification containing dimensions: + N - batch size + H - height + W - width + T - time/sequence length + C - channel + Defaults to "NHWT". + rescale_method (str, optional): + Data rescaling strategy identifier (e.g., "01" for 0-1 normalization). Defaults to "01". + """ + def __init__( + self, + data_types: Sequence[str] = None, + layout: str = "NHWT", + rescale_method: str = "01", + ): + super().__init__() + if data_types is None: + data_types = SEVIR_DATA_TYPES + else: + assert set(data_types).issubset(SEVIR_DATA_TYPES) + + self.layout = layout + self.data_types = data_types + self.rescale_method = rescale_method + + @staticmethod + def preprocess_data_dict(data_dict, data_types=None, layout="NHWT"): + """ + Parameterss + ---------- + data_dict: Dict[str, Union[np.ndarray, ms.Tensor]] + data_types: Sequence[str] + The data types that we want to rescale. This mainly excludes "mask" from preprocessing. + layout: str + consists of batch_size 'N', seq_in 'T', channel 'C', height 'H', width 'W' + Returns + ------- + data_dict: Dict[str, Union[np.ndarray, ms.Tensor]] + preprocessed data + """ + scale_dict = PREPROCESS_SCALE_01 + offset_dict = PREPROCESS_OFFSET_01 + if data_types is None: + data_types = data_dict.keys() + for key, data in data_dict.items(): + if key in data_types: + if isinstance(data, np.ndarray): + data = data.astype(np.float32) + elif isinstance(data, ms.Tensor): + data = data.float() + else: + raise TypeError + data = change_layout( + data=scale_dict[key] * (data + offset_dict[key]), + in_layout="NHWT", + out_layout=layout, + ) + data_dict[key] = data + return data_dict + + @staticmethod + def data_dict_to_tensor(data_dict, data_types=None): + """ + Convert each element in data_dict to ms.Tensor (copy without grad). + """ + ret_dict = {} + if data_types is None: + data_types = data_dict.keys() + for key, data in data_dict.items(): + if key in data_types: + if isinstance(data, ms.Tensor): + ret_dict[key] = data + elif isinstance(data, np.ndarray): + ret_dict[key] = Tensor.from_numpy(data) + else: + raise ValueError( + f"Invalid data type: {type(data)}. 
Should be ms.Tensor or np.ndarray" + ) + else: + ret_dict[key] = data + return ret_dict + + def process_data(self, data_dict): + """ + Processes the extracted data. + + Parameters + ---------- + data_dict + The dictionary containing the extracted data. + + Returns + ------- + processed_dict + The dictionary containing the processed data. + """ + split_tensors = data_dict.split(1, axis=0) + processed_tensors = [ + self.process_singledata(tensor) for tensor in split_tensors + ] + tensor_list = [] + for item in processed_tensors: + numpy_array = item["vil"] + tensor = Tensor(numpy_array) + tensor_list.append(tensor) + output_tensor = ops.Stack(axis=0)(tensor_list) + return output_tensor + + def process_singledata(self, singledata): + """process singledata""" + squeezed_tensor = ops.squeeze(singledata, 0) + singledata = {"vil": squeezed_tensor} + processed_dict = self.data_dict_to_tensor( + data_dict=singledata, data_types=self.data_types + ) + processed_dict = self.preprocess_data_dict( + data_dict=processed_dict, + data_types=self.data_types, + layout=self.layout, + ) + return processed_dict diff --git a/MindEarth/applications/nowcasting/PreDiff/src/utils.py b/MindEarth/applications/nowcasting/PreDiff/src/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..5727573a783b017ac215f4f3a5dcb53484505590 --- /dev/null +++ b/MindEarth/applications/nowcasting/PreDiff/src/utils.py @@ -0,0 +1,1062 @@ +# Copyright 2023 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"all util" +import math +import os +import shutil +import re +from copy import deepcopy +from inspect import isfunction +from typing import Dict, Any, Callable, Optional, Sequence +import numpy as np +import cv2 +from einops import repeat, rearrange + +import mindspore as ms +from mindspore import ops, mint, nn, Parameter, Tensor +from mindspore.train.metrics.metric import Metric +from mindspore.common.initializer import ( + initializer, + One, + Zero, + HeNormal, + Uniform, + TruncatedNormal, +) +from mindearth.utils import create_logger + + +PREPROCESS_SCALE_01 = { + "vis": 1, + "ir069": 1, + "ir107": 1, + "vil": 1 / 255, + "lght": 1, +} +PREPROCESS_OFFSET_01 = { + "vis": 0, + "ir069": 0, + "ir107": 0, + "vil": 0, + "lght": 0, +} + + +class DiagonalGaussianDistribution(nn.Cell): + """Diagonal Gaussian distribution layer for variational autoencoders. + + This class represents a diagonal Gaussian distribution parameterized by mean and log-variance, + supporting sampling, KL divergence computation, and negative log-likelihood evaluation. 
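# Quick numpy sketch of the "01" rescaling encoded by PREPROCESS_SCALE_01 /
# PREPROCESS_OFFSET_01 above: data is normalized as scale * (data + offset)
# (see `preprocess_data_dict` in the dataset module) and restored with
# data / scale - offset (see `process_data_dict_back` later in this file).
# Only the "vil" entry is shown.
import numpy as np

scale, offset = 1 / 255, 0
raw_vil = np.array([0.0, 128.0, 255.0], dtype=np.float32)
normed = scale * (raw_vil + offset)       # -> [0.0, ~0.502, 1.0]
restored = normed / scale - offset        # -> [0.0, 128.0, 255.0]
print(normed, restored)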
+ + Attributes: + mean (Tensor): Mean values of the distribution + logvar (Tensor): Clamped log-variance values + std (Tensor): Standard deviation derived from logvar + var (Tensor): Variance derived from logvar + deterministic (bool): Flag indicating deterministic sampling mode + """ + def __init__(self, parameters, deterministic=False): + super().__init__() + self.parameters = parameters + self.mean, self.logvar = ops.chunk(parameters, 2, axis=1) + self.logvar = ops.clamp(self.logvar, -30.0, 20.0) + + self.deterministic = deterministic + self.std = ops.exp(0.5 * self.logvar) + + self.var = ops.exp(self.logvar) + + if self.deterministic: + self.var = self.std = ops.zeros_like(self.mean) + + def sample(self): + """Generate a sample from the distribution. + + Returns: + Tensor: Sampled tensor with same shape as mean + + Notes: + - If deterministic=True, returns mean directly without noise + - Uses reparameterization trick for differentiable sampling + """ + sample = mint.randn(self.mean.shape) + + x = self.mean + self.std * sample + return x + + def kl(self, other=None): + """Compute KL divergence between this distribution and another or standard normal.""" + if self.deterministic: + return ms.Tensor([0.0]) + if other is None: + return 0.5 * ops.sum( + ops.pow(self.mean, 2) + self.var - 1.0 - self.logvar, dim=[1, 2, 3] + ) + return 0.5 * ops.sum( + ops.pow(self.mean - other.mean, 2) / other.var + + self.var / other.var + - 1.0 + - self.logvar + + other.logvar, + dim=[1, 2, 3], + ) + + def mode(self): + """Return the mode of the distribution (mean value).""" + return self.mean + + +def _threshold(target, pred, t_input): + """Apply thresholding to target and prediction tensors.""" + t = (target >= t_input).float() + p = (pred >= t_input).float() + is_nan = ops.logical_or(ops.isnan(target), ops.isnan(pred)) + t[is_nan] = 0 + p[is_nan] = 0 + return t, p + + +@staticmethod +def process_data_dict_back(data_dict, data_types=None): + """Rescale and offset data in dictionary using predefined parameters. + + Applies normalization using scale and offset values from global dictionaries. + + Args: + data_dict (dict): Dictionary containing data tensors + data_types (list, optional): Keys to process. Defaults to all keys in data_dict. + rescale (str, optional): Rescaling mode identifier. Defaults to "01". + + Returns: + dict: Processed data dictionary with normalized values + """ + scale_dict = PREPROCESS_SCALE_01 + offset_dict = PREPROCESS_OFFSET_01 + if data_types is None: + data_types = data_dict.keys() + for key in data_types: + data = data_dict[key] + data = data.float() / scale_dict[key] - offset_dict[key] + data_dict[key] = data + return data_dict + + +class SEVIRSkillScore(Metric): + """ + Class for calculating meteorological skill scores using threshold-based metrics. + This metric class computes performance metrics like CSI, POD, etc., + across multiple thresholds for weather prediction evaluation. 
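# Hedged numpy sketch (toy values) of what DiagonalGaussianDistribution above
# computes: reparameterized sampling x = mu + sigma * eps, and the closed-form
# KL divergence against a standard normal, 0.5 * sum(mu^2 + var - 1 - logvar).
import numpy as np

mu = np.array([0.5, -0.2])
logvar = np.clip(np.array([0.1, -1.0]), -30.0, 20.0)   # same clamp as above
std = np.exp(0.5 * logvar)

eps = np.random.randn(*mu.shape)
sample = mu + std * eps                                    # reparameterization trick
kl = 0.5 * np.sum(mu**2 + np.exp(logvar) - 1.0 - logvar)   # KL(N(mu, var) || N(0, I))
print(sample, kl)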
+ Args: + layout (str): Data dimension layout specification (default "NHWT") + mode (str): Operation mode affecting dimension handling ("0", "1", or "2") + seq_in (Optional[int]): Input sequence length (required for modes 1/2) + preprocess_type (str): Data preprocessing method ("sevir" or "sevir_pool*") + threshold_list (Sequence[int]): List of thresholds for binary classification + metrics_list (Sequence[str]): List of metrics to compute (csi, bias, sucr, pod) + eps (float): Small value to prevent division by zero + """ + def __init__( + self, + layout: str = "NHWT", + mode: str = "0", + seq_in: Optional[int] = None, + preprocess_type: str = "sevir", + threshold_list: Sequence[int] = (16, 74, 133, 160, 181, 219), + metrics_list: Sequence[str] = ("csi", "bias", "sucr", "pod"), + eps: float = 1e-4, + ): + super().__init__() + self.layout = layout + assert preprocess_type == "sevir" or preprocess_type.startswith("sevir_pool") + self.preprocess_type = preprocess_type + self.threshold_list = threshold_list + self.metrics_list = metrics_list + self.eps = eps + self.mode = mode + self.seq_in = seq_in + if mode in ("0",): + self.keep_seq_in_dim = False + state_shape = (len(self.threshold_list),) + elif mode in ("1", "2"): + self.keep_seq_in_dim = True + assert isinstance( + self.seq_in, int + ), "seq_in must be provided when we need to keep seq_in dim." + state_shape = (len(self.threshold_list), self.seq_in) + + else: + raise NotImplementedError(f"mode {mode} not supported!") + + self.hits = Parameter(ops.zeros(state_shape), name="hits") + self.misses = Parameter(ops.zeros(state_shape), name="misses") + self.fas = Parameter(ops.zeros(state_shape), name="fas") + + @property + def hits_misses_fas_reduce_dims(self): + """Dimensions to reduce when calculating metric statistics. + + Returns: + list[int]: List of dimensions to collapse during metric computation + """ + if not hasattr(self, "_hits_misses_fas_reduce_dims"): + seq_dim = self.layout.find("T") + self._hits_misses_fas_reduce_dims = list(range(len(self.layout))) + if self.keep_seq_in_dim: + self._hits_misses_fas_reduce_dims.pop(seq_dim) + return self._hits_misses_fas_reduce_dims + + def clear(self): + """Clear the internal states.""" + self.hits.set_data(mint.zeros_like(self.hits)) + self.misses.set_data(mint.zeros_like(self.misses)) + self.fas.set_data(mint.zeros_like(self.fas)) + + @staticmethod + def pod(hits, misses, _, eps): + """Probability of Detection""" + return hits / (hits + misses + eps) + + @staticmethod + def sucr(hits, _, fas, eps): + """Probability of hits""" + return hits / (hits + fas + eps) + + @staticmethod + def csi(hits, misses, fas, eps): + """critical success index""" + return hits / (hits + misses + fas + eps) + + @staticmethod + def bias(hits, misses, fas, eps): + """Bias score""" + bias = (hits + fas) / (hits + misses + eps) + logbias = ops.pow(bias / ops.log(Tensor(2.0)), 2.0) + return logbias + + def calc_seq_hits_misses_fas(self, pred, target, threshold): + """Calculate contingency table statistics for given threshold. 
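# Toy numpy illustration of the threshold-based skill scores defined above:
# both fields are binarized at a VIL threshold, the contingency counts
# (hits, misses, false alarms) are formed, and
# POD = hits/(hits+misses), SUCR = hits/(hits+fas), CSI = hits/(hits+misses+fas).
import numpy as np

eps, threshold = 1e-4, 74
target = np.array([[0, 80, 200], [10, 90, 60]], dtype=np.float32)
pred = np.array([[0, 70, 210], [85, 95, 50]], dtype=np.float32)

t = (target >= threshold).astype(np.float32)
p = (pred >= threshold).astype(np.float32)
hits = np.sum(t * p)            # predicted and observed
misses = np.sum(t * (1 - p))    # observed but not predicted
fas = np.sum((1 - t) * p)       # predicted but not observed

pod = hits / (hits + misses + eps)
sucr = hits / (hits + fas + eps)
csi = hits / (hits + misses + fas + eps)
print(hits, misses, fas, pod, sucr, csi)   # 2.0 1.0 1.0 ~0.67 ~0.67 0.5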
+ + Args: + pred (Tensor): Model prediction tensor + target (Tensor): Ground truth tensor + threshold (int): Threshold value for binarization + + Returns: + tuple[Tensor, Tensor, Tensor]: Hits, misses, false alarms + """ + t, p = _threshold(target, pred, threshold) + hits = ops.sum(t * p, dim=self.hits_misses_fas_reduce_dims).int() + misses = ops.sum(t * (1 - p), dim=self.hits_misses_fas_reduce_dims).int() + fas = ops.sum((1 - t) * p, dim=self.hits_misses_fas_reduce_dims).int() + return hits, misses, fas + + def preprocess(self, pred, target): + """Apply data preprocessing based on configuration. + + Handles SEVIR-specific normalization and optional spatial pooling. + + Args: + pred (Tensor): Raw model predictions + target (Tensor): Raw ground truth data + + Returns: + tuple[Tensor, Tensor]: Processed prediction and target tensors + """ + if self.preprocess_type == "sevir": + pred = process_data_dict_back(data_dict={"vil": pred.float()})["vil"] + target = process_data_dict_back(data_dict={"vil": target.float()})["vil"] + elif self.preprocess_type.startswith("sevir_pool"): + pred = process_data_dict_back(data_dict={"vil": pred.float()})["vil"] + target = process_data_dict_back(data_dict={"vil": target.float()})["vil"] + self.pool_scale = int(re.search(r"\d+", self.preprocess_type).group()) + batch_size = target.shape[0] + pred = rearrange( + pred, f"{self.einops_layout} -> {self.einops_spatial_layout}" + ) + target = rearrange( + target, f"{self.einops_layout} -> {self.einops_spatial_layout}" + ) + max_pool = nn.MaxPool2d( + kernel_size=self.pool_scale, stride=self.pool_scale, pad_mode="pad" + ) + pred = max_pool(pred) + target = max_pool(target) + pred = rearrange( + pred, + f"{self.einops_spatial_layout} -> {self.einops_layout}", + N=batch_size, + ) + target = rearrange( + target, + f"{self.einops_spatial_layout} -> {self.einops_layout}", + N=batch_size, + ) + else: + raise NotImplementedError + return pred, target + + def update(self, pred: Tensor, target: Tensor): + """Update metric statistics with new batch of predictions.""" + pred, target = self.preprocess(pred, target) + for i, threshold in enumerate(self.threshold_list): + hits, misses, fas = self.calc_seq_hits_misses_fas(pred, target, threshold) + self.hits[i] += hits + self.misses[i] += misses + self.fas[i] += fas + + def eval(self): + """Compute final metric scores across all thresholds.""" + metrics_dict = { + "pod": self.pod, + "csi": self.csi, + "sucr": self.sucr, + "bias": self.bias, + } + ret = {} + for threshold in self.threshold_list: + ret[threshold] = {} + ret["avg"] = {} + for metrics in self.metrics_list: + if self.keep_seq_in_dim: + score_avg = np.zeros((self.seq_in,)) + else: + score_avg = 0 + scores = metrics_dict[metrics](self.hits, self.misses, self.fas, self.eps) + scores = scores.asnumpy() + for i, threshold in enumerate(self.threshold_list): + if self.keep_seq_in_dim: + score = scores[i] + else: + score = scores[i].item() + if self.mode in ("0", "1"): + ret[threshold][metrics] = score + elif self.mode in ("2",): + ret[threshold][metrics] = np.mean(score).item() + else: + raise NotImplementedError + score_avg += score + score_avg /= len(self.threshold_list) + if self.mode in ("0", "1"): + ret["avg"][metrics] = score_avg + elif self.mode in ("2",): + ret["avg"][metrics] = np.mean(score_avg).item() + else: + raise NotImplementedError + return ret + + +def make_beta_schedule( + schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3 +): + """Generate beta schedule for diffusion models. 
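# Hedged numpy sketch of the "linear" variant of this beta schedule (mirroring
# the branch implemented below): betas are linearly spaced in sqrt-space and
# squared; the cumulative product of alphas = 1 - betas is the standard
# signal-retention curve of the diffusion forward process.
import numpy as np

n_timestep, linear_start, linear_end = 1000, 1e-4, 2e-2
betas = np.linspace(linear_start**0.5, linear_end**0.5, n_timestep, dtype=np.float64) ** 2
alphas_cumprod = np.cumprod(1.0 - betas)

print(betas[0], betas[-1])              # 1e-4 ... 2e-2
print(alphas_cumprod[[0, 499, 999]])    # fraction of clean signal left at t = 1, 500, 1000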
+ + Supports linear, cosine, sqrt_linear and sqrt schedules. + + Args: + schedule (str): Schedule type ("linear", "cosine", etc.) + n_timestep (int): Number of time steps + linear_start (float): Linear schedule start value + linear_end (float): Linear schedule end value + cosine_s (float): Cosine schedule shift parameter + + Returns: + Tensor: Beta values for each time step + """ + if schedule == "linear": + betas = ( + mint.linspace( + linear_start**0.5, linear_end**0.5, n_timestep, dtype=ms.float64 + ) + ** 2 + ) + + elif schedule == "cosine": + timesteps = ops.arange(n_timestep + 1, dtype=ms.float64) / n_timestep + cosine_s + alphas = timesteps / (1 + cosine_s) * np.pi / 2 + alphas = ops.cos(alphas).pow(2) + alphas = alphas / alphas[0] + betas = 1 - alphas[1:] / alphas[:-1] + betas = np.clip(betas, a_min=0, a_max=0.999) + + elif schedule == "sqrt_linear": + betas = ops.linspace(linear_start, linear_end, n_timestep) + elif schedule == "sqrt": + betas = ops.linspace(linear_start, linear_end, n_timestep) ** 0.5 + else: + raise ValueError(f"schedule '{schedule}' unknown.") + return betas.asnumpy() + + +def extract_into_tensor(a, t, x_shape, batch_axis=0): + """Extract tensor elements and reshape to match target dimensions.""" + batch_size = t.shape[0] + out = a.gather_elements(-1, t) + out_shape = [ + 1, + ] * len(x_shape) + out_shape[batch_axis] = batch_size + return out.reshape(out_shape) + + +def noise_like(shape): + """Generate random noise tensor matching given shape.""" + return ops.randn(shape) + + +def default(val, d): + """Return val if present, otherwise resolve default value.""" + if val is not None: + return val + return d() if isfunction(d) else d + + +def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False): + """ + Create sinusoidal timestep embeddings. + :param timesteps: a 1-D Tensor of N indices, one per batch element. + These may be fractional. + :param dim: the dimension of the output. + :param max_period: controls the minimum frequency of the embeddings. + :return: an [N x dim] Tensor of positional embeddings. + """ + if not repeat_only: + half = dim // 2 + freqs = ops.exp( + -math.log(max_period) + * ops.arange(start=0, end=half, dtype=ms.float32) + / half + ) + args = timesteps[:, None].float() * freqs[None] + embedding = ops.cat([ops.cos(args), ops.sin(args)], axis=-1) + if dim % 2: + embedding = ops.cat([embedding, ops.zeros_like(embedding[:, :1])], axis=-1) + else: + embedding = repeat(timesteps, "b -> b d", d=dim) + return embedding + + +def zero_module(module): + """ + Zero out the parameters of a module and return it. + """ + for param in module.trainable_params(): + param.set_data(Zero()(shape=param.shape, dtype=param.dtype)) + return module + + +def normalization(channels): + """ + Make a standard normalization layer. + :param channels: number of input channels. + :return: an nn.Module for normalization. + """ + num_groups = min(32, channels) + return nn.GroupNorm(num_groups, channels) + + +def conv_nd(dims, *args, **kwargs): + """ + Create a 1D, 2D, or 3D convolution module. + """ + if dims == 1: + return nn.Conv1d(*args, **kwargs, pad_mode="pad", has_bias=True) + if dims == 2: + return nn.Conv2d(*args, **kwargs, pad_mode="pad", has_bias=True) + return mint.nn.Conv3d(*args, **kwargs) + + + +def linear(*args, **kwargs): + """ + Create a linear module. + """ + return nn.Dense(*args, **kwargs) + + +def avg_pool_nd(dims, *args, **kwargs): + """ + Create a 1D, 2D, or 3D average pooling module. 
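# Hedged numpy equivalent of `timestep_embedding` above (repeat_only=False):
# half of the channels carry cos, half carry sin, with geometrically spaced
# frequencies controlled by max_period.
import math
import numpy as np

def timestep_embedding_np(timesteps, dim, max_period=10000):
    half = dim // 2
    freqs = np.exp(-math.log(max_period) * np.arange(half, dtype=np.float32) / half)
    args = timesteps[:, None].astype(np.float32) * freqs[None]
    emb = np.concatenate([np.cos(args), np.sin(args)], axis=-1)
    if dim % 2:
        emb = np.concatenate([emb, np.zeros_like(emb[:, :1])], axis=-1)
    return emb

print(timestep_embedding_np(np.array([0, 10, 500]), dim=8).shape)   # (3, 8)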
+ """ + if dims == 1: + return nn.AvgPool1d(*args, **kwargs) + if dims == 2: + return mint.nn.AvgPool2d(*args, **kwargs) + if dims == 3: + return nn.AvgPool3d(*args, **kwargs) + raise ValueError(f"unsupported dimensions: {dims}") + + + +def round_to(dat, c): + """round to""" + return dat + (dat - dat % c) % c + + +def get_activation(act, inplace=False, **kwargs): + """ + + Parameters + ---------- + act + Name of the activation + inplace + Whether to perform inplace activation + + Returns + ------- + activation_layer + The activation + """ + if act is None: + return lambda x: x + if isinstance(act, str): + if act == "leaky": + negative_slope = kwargs.get("negative_slope", 0.1) + return nn.LeakyReLU(negative_slope, inplace=inplace) + if act == "identity": + return nn.Identity() + if act == "elu": + return nn.ELU(inplace=inplace) + if act == "gelu": + return nn.GELU(approximate=False) + if act == "relu": + return nn.ReLU() + if act == "sigmoid": + return nn.Sigmoid() + if act == "tanh": + return nn.Tanh() + if act in ('softrelu', 'softplus'): + return ops.Softplus() + if act == "softsign": + return nn.Softsign() + raise NotImplementedError('act="{}" is not supported. ') + return act + + +def get_norm_layer( + norm_type: str = "layer_norm", + axis: int = -1, + epsilon: float = 1e-5, + in_channels: int = 0, + **kwargs, +): + """Get the normalization layer based on the provided type + + Parameters + ---------- + norm_type + The type of the layer normalization from ['layer_norm'] + axis + The axis to normalize the + epsilon + The epsilon of the normalization layer + in_channels + Input channel + + Returns + ------- + norm_layer + The layer normalization layer + """ + if isinstance(norm_type, str): + if norm_type == "layer_norm": + assert in_channels > 0 + assert axis == -1 + norm_layer = nn.LayerNorm( + normalized_shape=[in_channels], epsilon=epsilon, **kwargs + ) + else: + raise NotImplementedError("norm_type={} is not supported".format(norm_type)) + return norm_layer + if norm_type is None: + return nn.Identity() + raise NotImplementedError("The type of normalization must be str") + + +def generalize_padding(x, pad_t, pad_h, pad_w, padding_type, t_pad_left=False): + """ + + Parameters + ---------- + x + Shape (B, T, H, W, C) + pad_t + pad_h + pad_w + padding_type + t_pad_left + + Returns + ------- + out + The result after padding the x. Shape will be (B, T + pad_t, H + pad_h, W + pad_w, C) + """ + if pad_t == 0 and pad_h == 0 and pad_w == 0: + return x + + assert padding_type in ["zeros", "ignore", "nearest"] + _, t, h, w, _ = x.shape + + if padding_type == "nearest": + return ops.interpolate( + x.permute(0, 4, 1, 2, 3), size=(t + pad_t, h + pad_h, w + pad_w) + ).permute(0, 2, 3, 4, 1) + if t_pad_left: + return ops.pad(x, (0, 0, 0, pad_w, 0, pad_h, pad_t, 0)) + return ops.pad(x, (0, 0, 0, pad_w, 0, pad_h, 0, pad_t)) + + +def generalize_unpadding(x, pad_t, pad_h, pad_w, padding_type): + """Removes padding from a 5D tensor based on specified padding type and dimensions. + + Args: + x (Tensor): Input tensor with shape (batch, time, height, width, channels). + pad_t (int): Number of time steps to remove from the end. + pad_h (int): Number of height units to remove from the end. + pad_w (int): Number of width units to remove from the end. + padding_type (str): Type of padding removal method ("zeros", "ignore", "nearest"). + + Returns: + Tensor: Processed tensor with padding removed according to specified method. + + Raises: + AssertionError: If invalid padding_type is provided. 
+ """ + assert padding_type in ["zeros", "ignore", "nearest"] + _, t, h, w, _ = x.shape + if pad_t == 0 and pad_h == 0 and pad_w == 0: + return x + + if padding_type == "nearest": + return ops.interpolate( + x.permute(0, 4, 1, 2, 3), size=(t - pad_t, h - pad_h, w - pad_w) + ).permute(0, 2, 3, 4, 1) + return x[:, : (t - pad_t), : (h - pad_h), : (w - pad_w), :] + + +def _calculate_fan_in_and_fan_out(parameter): + """Calculates fan_in and fan_out values for neural network weight initialization.""" + dimensions = parameter.ndim + if dimensions < 2: + raise ValueError( + "Fan in and fan out can not be computed for parameter with fewer than 2 dimensions" + ) + num_input_fmaps = parameter.shape[1] + num_output_fmaps = parameter.shape[0] + receptive_field_size = 1 + if dimensions > 2: + for s in parameter.shape[2:]: + receptive_field_size *= s + fan_in = num_input_fmaps * receptive_field_size + fan_out = num_output_fmaps * receptive_field_size + return fan_in, fan_out + + +def apply_initialization( + cell, linear_mode="0", conv_mode="0", norm_mode="0", embed_mode="0" +): + """Applies parameter initialization strategies to neural network layers. + + Args: + cell (nn.Cell): Neural network layer to initialize. + linear_mode (str): Initialization mode for dense layers ("0", "1", "2"). + conv_mode (str): Initialization mode for convolutional layers ("0", "1", "2"). + norm_mode (str): Initialization mode for normalization layers ("0"). + embed_mode (str): Initialization mode for embedding layers ("0"). + + Raises: + NotImplementedError: If unsupported initialization mode is requested. + """ + if isinstance(cell, nn.Dense): + if linear_mode in ("0",): + cell.weight.set_data( + initializer( + HeNormal(mode="fan_in", nonlinearity="linear"), + cell.weight.shape, + cell.weight.dtype, + ) + ) + elif linear_mode in ("1",): + cell.weight.set_data( + initializer.initializer( + HeNormal(mode="fan_out", nonlinearity="leaky_relu"), + cell.weight.shape, + cell.weight.dtype, + ) + ) + elif linear_mode in ("2",): + zeros_tensor = ops.zeros(cell.weight.shape, cell.weight.dtype) + cell.weight.set_data(zeros_tensor) + else: + raise NotImplementedError + if hasattr(cell, "bias") and cell.bias is not None: + zeros_tensor = ops.zeros(cell.bias.shape, cell.bias.dtype) + cell.bias.set_data(zeros_tensor) + + elif isinstance( + cell, (nn.Conv2d, nn.Conv3d, nn.Conv2dTranspose, nn.Conv3dTranspose) + ): + if conv_mode in ("0",): + cell.weight.set_data( + initializer( + HeNormal( + negative_slope=math.sqrt(5), mode="fan_out", nonlinearity="relu" + ), + cell.weight.shape, + cell.weight.dtype, + ) + ) + if cell.has_bias: + fan_in, _ = _calculate_fan_in_and_fan_out(cell.weight) + if fan_in != 0: + bound = 1 / math.sqrt(fan_in) + cell.bias.set_data( + initializer(Uniform(bound), cell.bias.shape, cell.bias.dtype) + ) + elif conv_mode in ("1",): + cell.weight.set_data( + initializer( + HeNormal( + mode="fan_out", nonlinearity="leaky_relu", negative_slope=0.1 + ), + cell.weight.shape, + cell.weight.dtype, + ) + ) + if hasattr(mcell, "bias") and mcell.bias is not None: + cell.bias.set_data( + initializer(Zero(), cell.bias.shape, cell.bias.dtype) + ) + elif conv_mode in ("2",): + cell.weight.set_data( + initializer(Zero(), cell.weight.shape, cell.weight.dtype) + ) + if hasattr(m, "bias") and m.bias is not None: + cell.bias.set_data( + initializer(Zero(), cell.bias.shape, cell.bias.dtype) + ) + else: + raise NotImplementedError + + elif isinstance(cell, nn.GroupNorm): + if norm_mode in ("0",): + if cell.gamma is not None: + 
cell.gamma.set_data( + initializer(One(), cell.gamma.shape, cell.gamma.dtype) + ) + if cell.beta is not None: + cell.beta.set_data( + initializer(Zero(), cell.beta.shape, cell.beta.dtype) + ) + else: + raise NotImplementedError("Normalization mode not supported") + elif isinstance(cell, nn.Embedding): + if embed_mode == "0": + cell.embedding_table.set_data( + initializer( + TruncatedNormal(sigma=0.02), + cell.embedding_table.shape, + cell.embedding_table.dtype, + ) + ) + else: + raise NotImplementedError + else: + pass + + +def prepare_output_directory(base_config, device_id): + """Creates/updates output directory for experiment results. + + Args: + base_config (dict): Configuration dictionary containing directory paths. + device_id (int): Device identifier for directory naming. + + Returns: + str: Path to the created/updated output directory. + + Raises: + OSError: If directory operations fail unexpectedly. + """ + output_path = os.path.join( + base_config["summary"]["summary_dir"], f"single_device{device_id}" + ) + + try: + if os.path.exists(output_path): + shutil.rmtree(output_path) + print(f"Cleared previous output directory: {output_path}") + os.makedirs(output_path, exist_ok=True) + except OSError as e: + print(f"Directory operation failed: {e}", exc_info=True) + raise + base_config["summary"]["summary_dir"] = output_path + return output_path + + +def configure_logging_system(output_dir, config): + """Sets up logging system for the application. + + Args: + output_dir (str): Directory where logs should be stored. + config (dict): Configuration dictionary containing experiment parameters. + + Returns: + Logger: Configured logger instance. + """ + logger = create_logger(path=os.path.join(output_dir, "results.log")) + logger.info(f"Process ID: {os.getpid()}") + logger.info(config["summary"]) + return logger + + +def prepare_dataset(config, module): + """Initializes and prepares the dataset for training/evaluation. + + Args: + config (dict): Configuration dictionary with dataset parameters. + SEVIRPLModule (Module): Data module class for dataset handling. + + Returns: + tuple: (DataModule, total_num_steps) containing initialized data module and total training steps. + + Raises: + ValueError: If configuration is not provided. + """ + if config is not None: + dataset_cfg = config["data"] + total_batch_size = config["optim"]["total_batch_size"] + micro_batch_size = config["optim"]["micro_batch_size"] + max_epochs = config["optim"]["max_epochs"] + else: + raise ValueError("config is required but not provided") + dm = module.get_sevir_datamodule( + dataset_cfg=dataset_cfg, + micro_batch_size=micro_batch_size, + num_workers=8, + ) + dm.setup() + total_num_steps = module.get_total_num_steps( + epoch=max_epochs, + num_samples=dm.num_train_samples, + total_batch_size=total_batch_size, + ) + return dm, total_num_steps + + +def warmup_lambda(warmup_steps, min_lr_ratio=0.1): + """Creates a learning rate warmup schedule as a lambda function. + + Args: + warmup_steps (int): Number of steps for the warmup phase. + min_lr_ratio (float): Minimum learning rate ratio at the start of training. + + Returns: + function: Lambda function that calculates the warmup multiplier based on current step. + """ + def ret_lambda(epoch): + if epoch <= warmup_steps: + return min_lr_ratio + (1.0 - min_lr_ratio) * epoch / warmup_steps + return 1.0 + + return ret_lambda + + +def get_loss_fn(loss: str = "l2") -> Callable: + """ + Returns a loss function based on the provided loss type. 
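# Usage sketch for the `warmup_lambda` helper defined above (assumes this
# module is imported; step counts are illustrative): the multiplier ramps
# linearly from min_lr_ratio to 1.0 over warmup_steps and stays at 1.0 after,
# and the trainer scales the base learning rate by it.
warmup = warmup_lambda(warmup_steps=1000, min_lr_ratio=0.1)
print(warmup(0), warmup(500), warmup(1000), warmup(5000))
# 0.1  0.55  1.0  1.0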
+ + Args: + loss (str): Type of loss function. Default is "l2". + + Returns: + Callable: A loss function corresponding to the provided loss type. + """ + if loss in ("l2", "mse"): + return nn.MSELoss() + return nn.L1Loss() + + +def disabled_train(self): + """Overwrite model.train with this function to make sure train/eval mode + does not change anymore.""" + return self + + +def disable_train(model: nn.Cell): + """ + Disable training to avoid error when used in pl.LightningModule + """ + model.set_train(False) + model.train = disabled_train + return model + + +def layout_to_in_out_slice(layout, t_in, t_out=None): + """layout_to_in_out_slice""" + t_axis = layout.find("T") + num_axes = len(layout) + in_slice = [ + slice(None, None), + ] * num_axes + out_slice = deepcopy(in_slice) + in_slice[t_axis] = slice(None, t_in) + if t_out is None: + out_slice[t_axis] = slice(t_in, None) + else: + out_slice[t_axis] = slice(t_in, t_in + t_out) + return in_slice, out_slice + + +def parse_layout_shape(layout: str) -> Dict[str, Any]: + r""" + + Parameters + ---------- + layout: str + e.g., "NTHWC", "NHWC". + + Returns + ------- + ret: Dict + """ + batch_axis = layout.find("N") + t_axis = layout.find("T") + h_axis = layout.find("H") + w_axis = layout.find("W") + c_axis = layout.find("C") + return { + "batch_axis": batch_axis, + "t_axis": t_axis, + "h_axis": h_axis, + "w_axis": w_axis, + "c_axis": c_axis, + } + + +def ssim(img1, img2): + """Compute Structural Similarity Index (SSIM) between two images. + + Args: + img1 (np.ndarray): First input image (grayscale or single-channel), shape (H, W) + img2 (np.ndarray): Second input image with identical shape to img1 + + Returns: + float: SSIM value between 0 (completely dissimilar) and 1 (perfect similarity) + + Notes: + - Uses 11x11 Gaussian window with σ=1.5 for weighted filtering + - Follows the standard SSIM formulation with constants c1=0.0001, c2=0.0009 + - Computes valid convolution regions (edges truncated by kernel size) + """ + c1 = 0.01**2 + c2 = 0.03**2 + img1 = img1.astype(np.float64) + img2 = img2.astype(np.float64) + kernel = cv2.getGaussianKernel(11, 1.5) + window = np.outer(kernel, kernel.transpose()) + mu1 = cv2.filter2D(img1, -1, window)[5:-5, 5:-5] # valid + mu2 = cv2.filter2D(img2, -1, window)[5:-5, 5:-5] + mu1_sq = mu1**2 + mu2_sq = mu2**2 + mu1_mu2 = mu1 * mu2 + sigma1_sq = cv2.filter2D(img1**2, -1, window)[5:-5, 5:-5] - mu1_sq + sigma2_sq = cv2.filter2D(img2**2, -1, window)[5:-5, 5:-5] - mu2_sq + sigma12 = cv2.filter2D(img1 * img2, -1, window)[5:-5, 5:-5] - mu1_mu2 + ssim_map = ((2 * mu1_mu2 + c1) * (2 * sigma12 + c2)) / ( + (mu1_sq + mu2_sq + c1) * (sigma1_sq + sigma2_sq + c2) + ) + return ssim_map.mean() + + +def calculate_ssim_function(img1, img2): + """calculate ssim function""" + if not img1.shape == img2.shape: + raise ValueError("Input images must have the same dimensions.") + if img1.ndim == 2: + return ssim(img1, img2) + if img1.ndim == 3: + if img1.shape[0] == 3: + ssims = [] + for i in range(3): + ssims.append(ssim(img1[i], img2[i])) + return np.array(ssims).mean() + if img1.shape[0] == 1: + return ssim(np.squeeze(img1), np.squeeze(img2)) + raise ValueError("Wrong input image dimensions.") + + + + +def calculate_ssim(videos1, videos2): + """Calculate Structural Similarity Index (SSIM) between two video sequences across all timestamps. 
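# Illustrative use of `layout_to_in_out_slice` above (frame counts are
# hypothetical): for an "NTHWC" batch it builds slice lists that split the
# sequence axis into the first t_in context frames and the next t_out targets.
import numpy as np

batch = np.zeros((2, 13, 128, 128, 1), dtype=np.float32)   # N T H W C
in_slice, out_slice = layout_to_in_out_slice("NTHWC", t_in=7, t_out=6)

context = batch[tuple(in_slice)]    # (2, 7, 128, 128, 1)
target = batch[tuple(out_slice)]    # (2, 6, 128, 128, 1)
print(context.shape, target.shape)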
+ + Args: + videos1 (Tensor or np.ndarray): First video sequence with shape (batch_size, time_steps, + height, width, channels) + videos2 (Tensor or np.ndarray): Second video sequence with identical shape to videos1 + + Returns: + dict[int, float]: Dictionary where keys are timestamp indices and values are the mean SSIM values + across all batches for that timestamp + + Raises: + AssertionError: If input video tensors have different shapes + """ + ssim_results = [] + for video_num in range(videos1.shape[0]): + video1 = videos1[video_num] + video2 = videos2[video_num] + ssim_results_of_a_video = [] + for clip_timestamp in range(len(video1)): + img1 = video1[clip_timestamp] + img2 = video2[clip_timestamp] + ssim_results_of_a_video.append(calculate_ssim_function(img1, img2)) + ssim_results.append(ssim_results_of_a_video) + ssim_results = np.array(ssim_results) + ssim_score = {} + for clip_timestamp in range(len(video1)): + ssim_score[clip_timestamp] = np.mean(ssim_results[:, clip_timestamp]) + + return ssim_score + + +def init_model(module, config, mode): + """Initialize model with ckpt""" + summary_params = config.get("summary") + module.main_model.set_train(True) + if mode != "train": + summary_params["load_ckpt"] = "True" + module.main_model.set_train(False) + if summary_params["load_ckpt"]: + params = ms.load_checkpoint(summary_params.get("ckpt_path")) + ms.load_param_into_net( + module.main_model, params + ) + return module + +def self_axial(input_shape): + """Axial attention implementation from "Axial-Deeplab: + Efficient Convolutional Neural Networks for Semantic Segmentation" + Args: + input_shape (tuple): Input tensor shape (T, H, W, C). + Returns: + tuple: Axial attention parameters with separate temporal/spatial cuboids. + """ + t, h, w, _ = input_shape + cuboid_size = [(t, 1, 1), (1, h, 1), (1, 1, w)] + strategy = [("l", "l", "l"), ("l", "l", "l"), ("l", "l", "l")] + shift_size = [(0, 0, 0), (0, 0, 0), (0, 0, 0)] + return cuboid_size, strategy, shift_size diff --git a/MindEarth/applications/nowcasting/PreDiff/src/vae/__init__.py b/MindEarth/applications/nowcasting/PreDiff/src/vae/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..d8769bc9da95c1278d051eedb8d94c3536d44a8b --- /dev/null +++ b/MindEarth/applications/nowcasting/PreDiff/src/vae/__init__.py @@ -0,0 +1,20 @@ +# Copyright 2023 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this filepio[] except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================ +"""init""" +from .autoencoder_kl import AutoencoderKL + +__all__ = [ + "AutoencoderKL", +] diff --git a/MindEarth/applications/nowcasting/PreDiff/src/vae/autoencoder_kl.py b/MindEarth/applications/nowcasting/PreDiff/src/vae/autoencoder_kl.py new file mode 100644 index 0000000000000000000000000000000000000000..63d0b9d7e8b85f7408c3a1a50d2418450bb1b6a5 --- /dev/null +++ b/MindEarth/applications/nowcasting/PreDiff/src/vae/autoencoder_kl.py @@ -0,0 +1,378 @@ +# Copyright 2023 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"vae base class" +from typing import Tuple + +import mindspore as ms +import mindspore.nn as nn +import mindspore.ops as ops +import mindspore.mint as mint + +from src.utils import DiagonalGaussianDistribution +from .unet_2d_blocks import UNetMidBlock2D, get_down_block, get_up_block + + +class Encoder(nn.Cell): + """ + A class representing an encoder network for image encoding. + + Args: + in_channels (int): Number of input image channels (default: 3) + out_channels (int): Number of output image channels (default: 3) + down_block_types (tuple): Types of downsampling blocks (default: ("DownEncoderBlock2D",)) + block_out_channels (tuple): Output channels for each downsampling block (default: (64,)) + layers_per_block (int): Number of layers per downsampling block (default: 2) + norm_num_groups (int): Number of groups for group normalization (default: 32) + act_fn (str): Activation function type (default: "silu") + double_z (bool): Whether to double output channels (default: True) + + Returns: + None + """ + def __init__( + self, + in_channels=3, + out_channels=3, + down_block_types=("DownEncoderBlock2D",), + block_out_channels=(64,), + layers_per_block=2, + norm_num_groups=32, + act_fn="silu", + double_z=True, + ): + super().__init__() + self.layers_per_block = layers_per_block + self.conv_in = nn.Conv2d( + in_channels, + block_out_channels[0], + kernel_size=3, + stride=1, + padding=1, + has_bias=True, + pad_mode="pad", + ) + + self.mid_block = None + self.down_blocks = nn.CellList([]) + output_channel = block_out_channels[0] + for i, down_block_type in enumerate(down_block_types): + input_channel = output_channel + output_channel = block_out_channels[i] + is_final_block = i == len(block_out_channels) - 1 + + down_block = get_down_block( + down_block_type, + num_layers=self.layers_per_block, + in_channels=input_channel, + out_channels=output_channel, + add_downsample=not is_final_block, + resnet_eps=1e-6, + downsample_padding=0, + resnet_act_fn=act_fn, + resnet_groups=norm_num_groups, + attn_num_head_channels=None, + temb_channels=None, + ) + self.down_blocks.append(down_block) + self.mid_block = UNetMidBlock2D( + in_channels=block_out_channels[-1], + resnet_eps=1e-6, + resnet_act_fn=act_fn, + output_scale_factor=1, + resnet_time_scale_shift="default", + attn_num_head_channels=None, + 
resnet_groups=norm_num_groups, + temb_channels=None, + ) + self.conv_norm_out = nn.GroupNorm( + num_channels=block_out_channels[-1], num_groups=norm_num_groups, eps=1e-6 + ) + self.conv_act = nn.SiLU() + + conv_out_channels = 2 * out_channels if double_z else out_channels + self.conv_out = nn.Conv2d( + block_out_channels[-1], + conv_out_channels, + 3, + pad_mode="pad", + padding=1, + has_bias=True, + ) + + def construct(self, x): + """ + Forward pass through the encoder network. + + Args: + x (Tensor): Input image tensor + + Returns: + Tensor: Encoded output tensor + """ + + sample = self.conv_in(x) + + for _, down_block in enumerate(self.down_blocks): + sample = down_block(sample) + + sample = self.mid_block(sample) + sample = self.conv_norm_out(sample) + sample = self.conv_act(sample) + sample = self.conv_out(sample) + return sample + + +class Decoder(nn.Cell): + """ + Decoder class for the decoding process in image generation tasks. + + Args: + in_channels (int): Number of input channels, defaults to 3. + out_channels (int): Number of output channels, defaults to 3. + up_block_types (tuple): Types of upsample blocks, defaults to ("UpDecoderBlock2D",). + block_out_channels (tuple): Output channels for each block, defaults to (64,). + layers_per_block (int): Number of layers per block, defaults to 2. + norm_num_groups (int): Number of groups for normalization, defaults to 32. + act_fn (str): Activation function type, defaults to "silu". + + Attributes: + layers_per_block (int): Number of layers per block. + conv_in (nn.Conv2d): Input convolution layer. + mid_block (UNetMidBlock2D): Middle block. + up_blocks (nn.CellList): List of upsample blocks. + conv_norm_out (nn.GroupNorm): Output normalization layer. + conv_act (nn.SiLU): Output activation layer. + conv_out (nn.Conv2d): Output convolution layer. 
+ """ + def __init__( + self, + in_channels=3, + out_channels=3, + up_block_types=("UpDecoderBlock2D",), + block_out_channels=(64,), + layers_per_block=2, + norm_num_groups=32, + act_fn="silu", + ): + super().__init__() + self.layers_per_block = layers_per_block + + self.conv_in = nn.Conv2d( + in_channels, + block_out_channels[-1], + kernel_size=3, + stride=1, + padding=1, + pad_mode="pad", + has_bias=True, + ) + + self.mid_block = None + self.up_blocks = nn.CellList([]) + self.mid_block = UNetMidBlock2D( + in_channels=block_out_channels[-1], + resnet_eps=1e-6, + resnet_act_fn=act_fn, + output_scale_factor=1, + resnet_time_scale_shift="default", + attn_num_head_channels=None, + resnet_groups=norm_num_groups, + temb_channels=None, + ) + reversed_block_out_channels = list(reversed(block_out_channels)) + output_channel = reversed_block_out_channels[0] + for i, up_block_type in enumerate(up_block_types): + prev_output_channel = output_channel + output_channel = reversed_block_out_channels[i] + + is_final_block = i == len(block_out_channels) - 1 + + up_block = get_up_block( + up_block_type, + num_layers=self.layers_per_block + 1, + in_channels=prev_output_channel, + out_channels=output_channel, + prev_output_channel=None, + add_upsample=not is_final_block, + resnet_eps=1e-6, + resnet_act_fn=act_fn, + resnet_groups=norm_num_groups, + attn_num_head_channels=None, + temb_channels=None, + ) + self.up_blocks.append(up_block) + prev_output_channel = output_channel + self.conv_norm_out = nn.GroupNorm( + num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=1e-6 + ) + self.conv_act = nn.SiLU() + self.conv_out = nn.Conv2d( + block_out_channels[0], + out_channels, + 3, + padding=1, + pad_mode="pad", + has_bias=True, + ) + + def construct(self, z): + """ + Builds the decoder computation graph. + + Args: + z (Tensor): Input tensor. + + Returns: + Tensor: Decoded output tensor. + """ + sample = z + sample = self.conv_in(sample) + sample = self.mid_block(sample) + for up_block in self.up_blocks: + sample = up_block(sample) + sample = self.conv_norm_out(sample) + sample = self.conv_act(sample) + sample = self.conv_out(sample) + return sample + + +class AutoencoderKL(nn.Cell): + r"""Variational Autoencoder (VAE) model with KL loss from the paper Auto-Encoding Variational Bayes by Diederik P. Kingma + and Max Welling. + + This model inherits from [`ModelMixin`]. Check the superclass documentation for the generic methods the library + implements for all the model (such as downloading or saving, etc.) + + Parameters: + in_channels (int, *optional*, defaults to 3): Number of channels in the input image. + out_channels (int, *optional*, defaults to 3): Number of channels in the output. + down_block_types (`Tuple[str]`, *optional*, defaults to : + obj:`("DownEncoderBlock2D",)`): Tuple of downsample block types. + up_block_types (`Tuple[str]`, *optional*, defaults to : + obj:`("UpDecoderBlock2D",)`): Tuple of upsample block types. + block_out_channels (`Tuple[int]`, *optional*, defaults to : + obj:`(64,)`): Tuple of block output channels. + act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use. + latent_channels (`int`, *optional*, defaults to 4): Number of channels in the latent space. + sample_size (`int`, *optional*, defaults to `32`): TODO + scaling_factor (`float`, *optional*, defaults to 0.18215): + The component-wise standard deviation of the trained latent space computed using the first batch of the + training set. 
This is used to scale the latent space to have unit variance when training the diffusion + model. The latents are scaled with the formula `z = z * scaling_factor` before being passed to the + diffusion model. When decoding, the latents are scaled back to the original scale with the formula: `z = 1 + / scaling_factor * z`. For more details, refer to sections 4.3.2 and D.1 of the [High-Resolution Image + """ + + def __init__( + self, + in_channels: int = 3, + out_channels: int = 3, + down_block_types: Tuple[str] = ("DownEncoderBlock2D",), + up_block_types: Tuple[str] = ("UpDecoderBlock2D",), + block_out_channels: Tuple[int] = (64,), + layers_per_block: int = 1, + act_fn: str = "silu", + latent_channels: int = 4, + norm_num_groups: int = 32, + ): + super().__init__() + + # pass init params to Encoder + self.encoder = Encoder( + in_channels=in_channels, + out_channels=latent_channels, + down_block_types=down_block_types, + block_out_channels=block_out_channels, + layers_per_block=layers_per_block, + act_fn=act_fn, + norm_num_groups=norm_num_groups, + double_z=True, + ) + + # pass init params to Decoder + self.decoder = Decoder( + in_channels=latent_channels, + out_channels=out_channels, + up_block_types=up_block_types, + block_out_channels=block_out_channels, + layers_per_block=layers_per_block, + norm_num_groups=norm_num_groups, + act_fn=act_fn, + ) + + self.quant_conv = mint.nn.Conv2d(2 * latent_channels, 2 * latent_channels, 1) + self.post_quant_conv = mint.nn.Conv2d(latent_channels, latent_channels, 1) + self.use_slicing = False + + def encode(self, x: ms.Tensor) -> DiagonalGaussianDistribution: + h = self.encoder(x) + moments = self.quant_conv(h) + posterior = DiagonalGaussianDistribution(moments) + return posterior + + def _decode(self, z: ms.Tensor) -> ms.Tensor: + z = self.post_quant_conv(z) + dec = self.decoder(z) + return dec + + def enable_slicing(self): + r""" + Enable sliced VAE decoding. + + When this option is enabled, the VAE will split the input tensor in slices to compute decoding in several + steps. This is useful to save some memory and allow larger batch sizes. + """ + self.use_slicing = True + + def disable_slicing(self): + r""" + Disable sliced VAE decoding. If `enable_slicing` was previously invoked, this method will go back to computing + decoding in one step. + """ + self.use_slicing = False + + def decode(self, z: ms.Tensor) -> ms.Tensor: + if self.use_slicing and z.shape[0] > 1: + decoded_slices = [self._decode(z_slice) for z_slice in z.split(1)] + decoded = ops.cat(decoded_slices) + else: + decoded = self._decode(z) + return decoded + + def construct( + self, + sample: ms.Tensor, + sample_posterior: bool = False, + return_posterior: bool = False, + ) -> ms.Tensor: + r""" + Args: + sample (`ms.FloatTensor`): Input sample. + sample_posterior (`bool`, *optional*, defaults to `False`): + Whether to sample from the posterior. + return_posterior (`bool`, *optional*, defaults to `False`): + Whether or not to return `posterior` along with `dec` for calculating the training loss. 
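# Hedged usage sketch for the AutoencoderKL defined above; channel counts and
# spatial size are illustrative, not the PreDiff training configuration.
from mindspore import ops

vae = AutoencoderKL(in_channels=1, out_channels=1)   # other arguments keep their defaults
frames = ops.randn((2, 1, 64, 64))                   # (N, C, H, W) radar-like frames
posterior = vae.encode(frames)                       # DiagonalGaussianDistribution
z = posterior.sample()                               # reparameterized latent, 4 channels
recon = vae.decode(z)                                # back to (2, 1, 64, 64)
print(z.shape, recon.shape)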
+ """ + + posterior = self.encode(sample) + + if sample_posterior: + z = posterior.sample() + else: + z = posterior.mode() + dec = self.decode(z) + if return_posterior: + return dec, posterior + return dec diff --git a/MindEarth/applications/nowcasting/PreDiff/src/vae/resnet.py b/MindEarth/applications/nowcasting/PreDiff/src/vae/resnet.py new file mode 100644 index 0000000000000000000000000000000000000000..6a20d3b5fdb1ad343105c33e6649708d4745810f --- /dev/null +++ b/MindEarth/applications/nowcasting/PreDiff/src/vae/resnet.py @@ -0,0 +1,907 @@ +# Copyright 2023 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"resnet model" +from functools import partial +import numpy as np + +import mindspore as ms +from mindspore import Tensor, mint, nn, ops + + +class AvgPool1d(nn.Cell): + """ + 1D average pooling layer implementation with customizable kernel size, stride, and padding. + Performs spatial downsampling by computing average values over sliding windows. + """ + def __init__(self, kernel_size, stride=1, padding=0): + """ + Initialize 1D average pooling parameters with validation checks. + + Args: + kernel_size (int): Length of the pooling window + stride (int): Stride size for window movement (default=1) + padding (int): Zero-padding added to both sides of input (default=0) + + Raises: + ValueError: If kernel_size ≤ 0, stride ≤ 0, or padding < 0 + """ + super().__init__() + self.kernel_size = kernel_size + self.stride = stride + self.padding = padding + self.mean = ops.ReduceMean(keep_dims=False) + if stride <= 0: + raise ValueError("stride must be positive") + if kernel_size <= 0: + raise ValueError("kernel_size must be positive") + if padding < 0: + raise ValueError("padding must be non-negative") + + def construct(self, x): + """ + Apply 1D average pooling to input tensor. + """ + input_shape = x.shape + n, c, l_in = input_shape[0], input_shape[1], input_shape[2] + pad_left = self.padding + pad_right = self.padding + x = ops.Pad(((0, 0), (0, 0), (pad_left, pad_right)))(x) + l_in += pad_left + pad_right + l_out = (l_in - self.kernel_size) // self.stride + 1 + output = Tensor(np.zeros((n, c, l_out)), dtype=ms.float32) + for i in range(l_out): + start = i * self.stride + end = start + self.kernel_size + if end <= l_in: + window = x[:, :, start:end] + output[:, :, i] = self.mean(window, -1) + + return output + + +class Upsample1D(nn.Cell): + """ + An upsampling layer with an optional convolution. + + Parameters: + channels: channels in the inputs and outputs. + use_conv: a bool determining if a convolution is applied. 
+ use_conv_transpose: + out_channels: + """ + + def __init__( + self, + channels, + use_conv=False, + use_conv_transpose=False, + out_channels=None, + name="conv", + ): + super().__init__() + self.channels = channels + self.out_channels = out_channels or channels + self.use_conv = use_conv + self.use_conv_transpose = use_conv_transpose + self.name = name + + self.conv = None + if use_conv_transpose: + self.conv = nn.Conv1dTranspose( + channels, + self.out_channels, + kernel_size=4, + stride=2, + pad_mode="pad", + padding=1, + has_bias=True, + ) + elif use_conv: + self.conv = nn.Conv1d( + self.channels, + self.out_channels, + 3, + padding=1, + pad_mode="pad", + has_bias=True, + ) + + def construct(self, x): + """forward""" + assert x.shape[1] == self.channels + if self.use_conv_transpose: + return self.conv(x) + + x = ops.interpolate(x, scale_factor=2.0, mode="nearest") + + if self.use_conv: + x = self.conv(x) + + return x + + +class Downsample1D(nn.Cell): + """ + A downsampling layer with an optional convolution. + + Parameters: + channels: channels in the inputs and outputs. + use_conv: a bool determining if a convolution is applied. + out_channels: + padding: + """ + + def __init__( + self, channels, use_conv=False, out_channels=None, padding=1, name="conv" + ): + super().__init__() + self.channels = channels + self.out_channels = out_channels or channels + self.use_conv = use_conv + self.padding = padding + stride = 2 + self.name = name + + if use_conv: + self.conv = nn.Conv1d( + self.channels, + self.out_channels, + 3, + stride=stride, + padding=padding, + pad_mode="pad", + has_bias=True, + ) + else: + assert self.channels == self.out_channels + self.conv = AvgPool1d(kernel_size=stride, stride=stride) + + def construct(self, x): + assert x.shape[1] == self.channels + return self.conv(x) + + +class Upsample2D(nn.Cell): + """ + An upsampling layer with an optional convolution. + + Parameters: + channels: channels in the inputs and outputs. + use_conv: a bool determining if a convolution is applied. 
+ use_conv_transpose: + out_channels: + """ + + def __init__( + self, + channels, + use_conv=False, + use_conv_transpose=False, + out_channels=None, + name="conv", + ): + super().__init__() + self.channels = channels + self.out_channels = out_channels or channels + self.use_conv = use_conv + self.use_conv_transpose = use_conv_transpose + self.name = name + + conv = None + if use_conv_transpose: + conv = nn.Conv2dTranspose( + channels, + self.out_channels, + kernel_size=4, + stride=2, + padding=1, + pad_mode="pad", + has_bias=True, + ) + elif use_conv: + conv = nn.Conv2d( + self.channels, + self.out_channels, + kernel_size=3, + padding=1, + pad_mode="pad", + has_bias=True, + ) + if name == "conv": + self.conv = conv + else: + self.conv2d_0 = conv + + def construct(self, hidden_states, output_size=None): + """forward""" + assert hidden_states.shape[1] == self.channels + + if self.use_conv_transpose: + return self.conv(hidden_states) + + dtype = hidden_states.dtype + if dtype == ms.bfloat16: + hidden_states = hidden_states.to(ms.float32) + if hidden_states.shape[0] >= 64: + hidden_states = hidden_states.contiguous() + if output_size is None: + hidden_states = ops.interpolate( + hidden_states, + scale_factor=2.0, + recompute_scale_factor=True, + mode="nearest", + ) + else: + hidden_states = ops.interpolate( + hidden_states, size=output_size, mode="nearest" + ) + + if dtype == ms.bfloat16: + hidden_states = hidden_states.to(dtype) + if self.use_conv: + if self.name == "conv": + hidden_states = self.conv(hidden_states) + else: + hidden_states = self.conv2d_0(hidden_states) + + return hidden_states + + +class Downsample2D(nn.Cell): + """ + A downsampling layer with an optional convolution. + + Parameters: + channels: channels in the inputs and outputs. + use_conv: a bool determining if a convolution is applied. + out_channels: + padding: + """ + + def __init__( + self, channels, use_conv=False, out_channels=None, padding=1, name="conv" + ): + super().__init__() + self.channels = channels + self.out_channels = out_channels or channels + self.use_conv = use_conv + self.padding = padding + stride = 2 + self.name = name + + if use_conv: + conv = nn.Conv2d( + self.channels, + self.out_channels, + kernel_size=3, + stride=stride, + padding=padding, + pad_mode="pad", + has_bias=True, + ) + else: + assert self.channels == self.out_channels + conv = mint.nn.AvgPool2d(kernel_size=stride, stride=stride) + if name == "conv": + self.conv2d_0 = conv + self.conv = conv + elif name == "Conv2d_0": + self.conv = conv + else: + self.conv = conv + + def construct(self, hidden_states): + """forward""" + assert hidden_states.shape[1] == self.channels + if self.use_conv and self.padding == 0: + pad = (0, 1, 0, 1) + hidden_states = ops.pad(hidden_states, pad, mode="constant", value=None) + + assert hidden_states.shape[1] == self.channels + hidden_states = self.conv(hidden_states) + + return hidden_states + + +class FirUpsample2D(nn.Cell): + """ + 2D upsampling layer with optional FIR filtering and convolutional projection. + Implements pixel-shuffle based upsampling with optional convolutional transformation. + """ + def __init__( + self, channels=None, out_channels=None, use_conv=False, fir_kernel=(1, 3, 3, 1) + ): + """ + Initialize upsample layer parameters. 
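+        Upsampling is performed with `upfirdn2d_native` using a normalized FIR kernel for
+        anti-aliasing; when `use_conv` is set, the layer's 3x3 convolution weights are applied
+        as a transposed convolution before the filtering step and the bias is added afterwards.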
+ + Args: + channels (int): Number of input channels + out_channels (int): Number of output channels (defaults to input channels if not specified) + use_conv (bool): Whether to apply 3x3 convolution after upsampling + fir_kernel (tuple): FIR filter kernel coefficients for antialiasing + + Raises: + ValueError: If invalid kernel parameters are provided + """ + super().__init__() + out_channels = out_channels if out_channels else channels + if use_conv: + self.conv2d_0 = nn.Conv2d( + channels, + out_channels, + kernel_size=3, + stride=1, + padding=1, + pad_mode="pad", + has_bias=True, + ) + self.use_conv = use_conv + self.fir_kernel = fir_kernel + self.out_channels = out_channels + + def _upsample_2d(self, hidden_states, weight=None, kernel=None, factor=2, gain=1): + """ + Core upsampling operation with optional convolution and FIR filtering. + """ + + assert isinstance(factor, int) and factor >= 1 + + # Setup filter kernel. + if kernel is None: + kernel = [1] * factor + + # setup kernel + kernel = Tensor(kernel, dtype=ms.float32) + if kernel.ndim == 1: + kernel = ops.outer(kernel, kernel) + kernel /= ops.sum(kernel) + + kernel = kernel * (gain * (factor**2)) + + if self.use_conv: + convh = weight.shape[2] + convw = weight.shape[3] + in_c = weight.shape[1] + + pad_value = (kernel.shape[0] - factor) - (convw - 1) + + stride = (factor, factor) + # Determine data dimensions. + output_shape = ( + (hidden_states.shape[2] - 1) * factor + convh, + (hidden_states.shape[3] - 1) * factor + convw, + ) + output_padding = ( + output_shape[0] - (hidden_states.shape[2] - 1) * stride[0] - convh, + output_shape[1] - (hidden_states.shape[3] - 1) * stride[1] - convw, + ) + assert output_padding[0] >= 0 and output_padding[1] >= 0 + num_groups = hidden_states.shape[1] // in_c + + # Transpose weights. + weight = ops.reshape(weight, (num_groups, -1, in_c, convh, convw)) + weight = ops.flip(weight, dims=[3, 4]).permute(0, 2, 1, 3, 4) + weight = ops.reshape(weight, (num_groups * in_c, -1, convh, convw)) + conv_transpose2d = nn.Conv2dTranspose( + weight[0], + weight[1], + (weight[2], weight[3]), + stride=stride, + output_padding=output_padding, + padding=0, + pad_mode="pad", + ) + inverse_conv = conv_transpose2d(hidden_states) + + output = upfirdn2d_native( + inverse_conv, + ms.tensor(kernel), + pad=((pad_value + 1) // 2 + factor - 1, pad_value // 2 + 1), + ) + else: + pad_value = kernel.shape[0] - factor + output = upfirdn2d_native( + hidden_states, + ms.tensor( + kernel, + ), + up=factor, + pad=((pad_value + 1) // 2 + factor - 1, pad_value // 2), + ) + + return output + + def construct(self, hidden_states): + """ + Apply upsampling transformation with optional convolutional projection. + """ + if self.use_conv: + height = self._upsample_2d( + hidden_states, self.conv2d_0.weight, kernel=self.fir_kernel + ) + height = height + self.conv2d_0.bias.reshape(1, -1, 1, 1) + else: + height = self._upsample_2d(hidden_states, kernel=self.fir_kernel, factor=2) + + return height + + +class FirDownsample2D(nn.Cell): + """ + 2D downsampling layer with optional FIR filtering and convolutional projection. + Implements anti-aliased downsampling with optional 3x3 convolution. + """ + def __init__( + self, channels=None, out_channels=None, use_conv=False, fir_kernel=(1, 3, 3, 1) + ): + """ + Initialize downsampling layer parameters. 
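+        The input is low-pass filtered with `upfirdn2d_native` using a normalized FIR kernel and
+        then reduced by `factor`, either through a strided 3x3 convolution (`use_conv`) or by
+        plain subsampling.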
+ Args: + channels (int): Number of input channels + out_channels (int): Number of output channels (defaults to input channels if not specified) + use_conv (bool): Whether to apply 3x3 convolution before downsampling + fir_kernel (tuple): FIR filter kernel coefficients for antialiasing + + Raises: + ValueError: If invalid kernel parameters are provided + """ + super().__init__() + out_channels = out_channels if out_channels else channels + if use_conv: + self.conv2d_0 = nn.Conv2d( + channels, + out_channels, + kernel_size=3, + stride=1, + padding=1, + pad_mode="pad", + has_bias=True, + ) + self.fir_kernel = fir_kernel + self.use_conv = use_conv + self.out_channels = out_channels + + def _downsample_2d(self, hidden_states, weight=None, kernel=None, factor=2, gain=1): + """ + Core downsampling operation with optional convolution and FIR filtering. + """ + assert isinstance(factor, int) and factor >= 1 + if kernel is None: + kernel = [1] * factor + + # setup kernel + kernel = ms.tensor(kernel, dtype=ms.float32) + if kernel.ndim == 1: + kernel = ms.outer(kernel, kernel) + kernel /= ms.sum(kernel) + + kernel = kernel * gain + + if self.use_conv: + _, _, _, convw = weight.shape + pad_value = (kernel.shape[0] - factor) + (convw - 1) + stride_value = [factor, factor] + upfirdn_input = upfirdn2d_native( + hidden_states, + ms.tensor(kernel), + pad=((pad_value + 1) // 2, pad_value // 2), + ) + output = ops.conv2d(upfirdn_input, weight, stride=stride_value, padding=0) + else: + pad_value = kernel.shape[0] - factor + output = upfirdn2d_native( + hidden_states, + ms.tensor(kernel), + down=factor, + pad=((pad_value + 1) // 2, pad_value // 2), + ) + + return output + + def construct(self, hidden_states): + """ + Apply downsampling transformation with optional convolutional projection. + """ + if self.use_conv: + downsample_input = self._downsample_2d( + hidden_states, weight=self.conv2d_0.weight, kernel=self.fir_kernel + ) + hidden_states = downsample_input + self.conv2d_0.bias.reshape(1, -1, 1, 1) + else: + hidden_states = self._downsample_2d( + hidden_states, kernel=self.fir_kernel, factor=2 + ) + + return hidden_states + + +class ResnetBlock2D(nn.Cell): + """ + 2D ResNet block with optional time embeddings and spatial transformations. + Implements pre-activation residual connections with optional upsampling/downsampling. + """ + def __init__( + self, + *, + in_channels, + out_channels=None, + conv_shortcut=False, + dropout=0.0, + temb_channels=512, + groups=32, + groups_out=None, + pre_norm=True, + eps=1e-6, + non_linearity="swish", + time_embedding_norm="default", + kernel=None, + output_scale_factor=1.0, + use_in_shortcut=None, + up=False, + down=False, + ): + """ + Initialize ResNet block with configurable normalization and spatial transformations. 
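+        The block follows a pre-activation layout (GroupNorm -> nonlinearity -> conv); the
+        optional time embedding is injected either additively ("default") or as a scale/shift
+        modulation ("scale_shift").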
+
+        Args:
+            in_channels (int): Number of input channels
+            out_channels (int): Number of output channels (defaults to in_channels)
+            conv_shortcut (bool): Use 1x1 convolution for shortcut connection
+            dropout (float): Dropout probability (default=0)
+            temb_channels (int): Time embedding dimension (default=512)
+            groups (int): Number of groups for group normalization
+            groups_out (int): Groups for second normalization layer (defaults to groups)
+            pre_norm (bool): Apply normalization before non-linearity
+            eps (float): Epsilon for numerical stability in normalization
+            non_linearity (str): Activation function type ("swish", "mish", "silu")
+            time_embedding_norm (str): Time embedding normalization mode ("default" or "scale_shift")
+            kernel (str): Upsample/downsample kernel type ("fir", "sde_vp")
+            output_scale_factor (float): Output scaling factor (default=1.0)
+            use_in_shortcut (bool): Force shortcut connection usage
+            up (bool): Enable upsampling transformation
+            down (bool): Enable downsampling transformation
+
+        Raises:
+            ValueError: If invalid non_linearity or time_embedding_norm values are provided
+        """
+        super().__init__()
+        self.pre_norm = pre_norm
+        self.pre_norm = True
+        self.in_channels = in_channels
+        out_channels = in_channels if out_channels is None else out_channels
+        self.out_channels = out_channels
+        self.use_conv_shortcut = conv_shortcut
+        self.time_embedding_norm = time_embedding_norm
+        self.up = up
+        self.down = down
+        self.output_scale_factor = output_scale_factor
+        self.groups = groups
+        self.eps = eps
+        if groups_out is None:
+            groups_out = groups
+
+        self.norm1 = nn.GroupNorm(
+            num_groups=groups, num_channels=in_channels, eps=eps, affine=True
+        )
+
+        self.conv1 = nn.Conv2d(
+            in_channels,
+            out_channels,
+            kernel_size=3,
+            stride=1,
+            padding=1,
+            pad_mode="pad",
+            has_bias=True,
+        )
+
+        if temb_channels is not None:
+            if self.time_embedding_norm == "default":
+                time_emb_proj_out_channels = out_channels
+            elif self.time_embedding_norm == "scale_shift":
+                time_emb_proj_out_channels = out_channels * 2
+            else:
+                raise ValueError(
+                    f"unknown time_embedding_norm : {self.time_embedding_norm}"
+                )
+
+            self.time_emb_proj = nn.Dense(temb_channels, time_emb_proj_out_channels)
+        else:
+            self.time_emb_proj = None
+
+        self.norm2 = nn.GroupNorm(
+            num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True
+        )
+        self.dropout = nn.Dropout(p=dropout)
+        self.conv2 = mint.nn.Conv2d(
+            out_channels, out_channels, kernel_size=3, stride=1, padding=1
+        )
+        if non_linearity == "swish":
+            self.nonlinearity = nn.SiLU()
+        elif non_linearity == "mish":
+            self.nonlinearity = Mish()
+        elif non_linearity == "silu":
+            self.nonlinearity = nn.SiLU()
+        else:
+            raise ValueError(f"unknown non_linearity : {non_linearity}")
+
+        self.upsample = self.downsample = None
+        if self.up:
+            if kernel == "fir":
+                fir_kernel = (1, 3, 3, 1)
+                self.upsample = lambda x: upsample_2d(x, kernel=fir_kernel)
+            elif kernel == "sde_vp":
+                self.upsample = partial(
+                    ops.interpolate, scale_factor=2.0, mode="nearest"
+                )
+            else:
+                self.upsample = Upsample2D(in_channels, use_conv=False)
+        elif self.down:
+            if kernel == "fir":
+                fir_kernel = (1, 3, 3, 1)
+                self.downsample = lambda x: downsample_2d(x, kernel=fir_kernel)
+            elif kernel == "sde_vp":
+                self.downsample = mint.nn.AvgPool2d(kernel_size=2, stride=2)
+            else:
+                self.downsample = Downsample2D(
+                    in_channels, use_conv=False, padding=1, name="op"
+                )
+
+        self.use_in_shortcut = (
+            self.in_channels != self.out_channels
+            if use_in_shortcut is None
+            else use_in_shortcut
+        )
+
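+        # The 1x1 shortcut convolution defined below projects the residual branch
+        # when the input and output channel counts differ (or when use_in_shortcut
+        # forces it), so the skip connection can be added to the block output.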
self.conv_shortcut = None + if self.use_in_shortcut: + self.conv_shortcut = nn.Conv2d( + in_channels, + out_channels, + kernel_size=1, + stride=1, + padding=0, + pad_mode="pad", + has_bias=True, + ) + + def construct(self, input_tensor, temb): + """ + Forward pass of the ResNet block. + + Args: + input_tensor (Tensor): Input tensor of shape (batch, channels, height, width). + temb (Tensor): Optional time embedding tensor. + + Returns: + Tensor: Output tensor after applying residual block operations. + """ + hidden_states = input_tensor + hidden_states = self.norm1(hidden_states) + hidden_states = self.nonlinearity(hidden_states) + if self.upsample is not None: + input_tensor = self.upsample(input_tensor) + hidden_states = self.upsample(hidden_states) + + elif self.downsample is not None: + input_tensor = self.downsample(input_tensor) + hidden_states = self.downsample(hidden_states) + hidden_states = self.conv1(hidden_states) + if temb is not None: + temb = self.time_emb_proj(self.nonlinearity(temb))[:, :, None, None] + if temb is not None and self.time_embedding_norm == "default": + hidden_states = hidden_states + temb + hidden_states = self.norm2(hidden_states) + if temb is not None and self.time_embedding_norm == "scale_shift": + scale, shift = ops.chunk(temb, 2, axis=1) + hidden_states = hidden_states * (1 + scale) + shift + hidden_states = self.nonlinearity(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = self.conv2(hidden_states) + if self.conv_shortcut is not None: + input_tensor = self.conv_shortcut(input_tensor) + output_tensor = (input_tensor + hidden_states) / self.output_scale_factor + return output_tensor + + +class Mish(nn.Cell): + """Implements the Mish activation function: x * tanh(softplus(x)).""" + def __init__(self): + super().__init__() + self.tanh = ops.Tanh() + self.softplus = ops.Softplus() + + def construct(self, hidden_states): + """Compute Mish activation on input tensor.""" + return hidden_states * self.tanh(self.softplus(hidden_states)) + + +def rearrange_dims(tensor): + """ + Adjust tensor dimensions based on input shape: + - 2D → add two singleton dimensions + - 3D → add one singleton dimension + - 4D → squeeze spatial dimensions + + Args: + tensor (Tensor): Input tensor. + + Returns: + Tensor: Dimension-adjusted tensor. + + Raises: + ValueError: If input tensor has invalid dimensions. + """ + if len(tensor.shape) == 2: + return tensor[:, :, None] + if len(tensor.shape) == 3: + return tensor[:, :, None, :] + if len(tensor.shape) == 4: + return tensor[:, :, 0, :] + raise ValueError(f"`len(tensor)`: {len(tensor)} has to be 2, 3 or 4.") + + +class Conv1dBlock(nn.Cell): + """ + 1D Convolution block with GroupNorm and Mish activation. + + Args: + inp_channels (int): Number of input channels. + out_channels (int): Number of output channels. + kernel_size (int): Convolution kernel size. + n_groups (int): Number of groups for GroupNorm. Defaults to 8. 
+    """
+    def __init__(self, inp_channels, out_channels, kernel_size, n_groups=8):
+        super().__init__()
+
+        self.conv1d = nn.Conv1d(
+            inp_channels,
+            out_channels,
+            kernel_size,
+            padding=kernel_size // 2,
+            has_bias=True,
+            pad_mode="pad",
+        )
+        self.group_norm = nn.GroupNorm(n_groups, out_channels)
+        self.mish = nn.Mish()
+
+    def construct(self, x):
+        """Apply convolution, normalization, dimension rearrangement and activation."""
+        x = self.conv1d(x)
+        x = rearrange_dims(x)
+        x = self.group_norm(x)
+        x = rearrange_dims(x)
+        x = self.mish(x)
+        return x
+
+
+class ResidualTemporalBlock1D(nn.Cell):
+    """ResidualTemporalBlock1D"""
+    def __init__(self, inp_channels, out_channels, embed_dim, kernel_size=5):
+        super().__init__()
+        self.conv_in = Conv1dBlock(inp_channels, out_channels, kernel_size)
+        self.conv_out = Conv1dBlock(out_channels, out_channels, kernel_size)
+
+        self.time_emb_act = nn.Mish()
+        self.time_emb = nn.Dense(embed_dim, out_channels)
+
+        self.residual_conv = (
+            nn.Conv1d(inp_channels, out_channels, 1, has_bias=True, pad_mode="valid")
+            if inp_channels != out_channels
+            else nn.Identity()
+        )
+
+    def construct(self, x, t):
+        """
+        Args:
+            x : [ batch_size x inp_channels x horizon ]
+            t : [ batch_size x embed_dim ]
+
+        Returns:
+            out : [ batch_size x out_channels x horizon ]
+        """
+        t = self.time_emb_act(t)
+        t = self.time_emb(t)
+        out = self.conv_in(x) + rearrange_dims(t)
+        out = self.conv_out(out)
+        return out + self.residual_conv(x)
+
+
+def upsample_2d(hidden_states, kernel=None, factor=2, gain=1):
+    """Upsample a batch of 2D images with the given FIR filter."""
+    assert isinstance(factor, int) and factor >= 1
+    if kernel is None:
+        kernel = [1] * factor
+
+    kernel = ms.tensor(kernel, dtype=ms.float32)
+    if kernel.ndim == 1:
+        kernel = ops.outer(kernel, kernel)
+    kernel /= ops.sum(kernel)
+
+    kernel = kernel * (gain * (factor**2))
+    pad_value = kernel.shape[0] - factor
+    output = upfirdn2d_native(
+        hidden_states,
+        kernel,
+        up=factor,
+        pad=((pad_value + 1) // 2 + factor - 1, pad_value // 2),
+    )
+    return output
+
+
+def downsample_2d(hidden_states, kernel=None, factor=2, gain=1):
+    """Downsample a batch of 2D images with the given FIR filter."""
+    assert isinstance(factor, int) and factor >= 1
+    if kernel is None:
+        kernel = [1] * factor
+
+    kernel = ms.tensor(kernel, dtype=ms.float32)
+    if kernel.ndim == 1:
+        kernel = ops.outer(kernel, kernel)
+    kernel /= ops.sum(kernel)
+
+    kernel = kernel * gain
+    pad_value = kernel.shape[0] - factor
+    output = upfirdn2d_native(
+        hidden_states,
+        kernel,
+        down=factor,
+        pad=((pad_value + 1) // 2, pad_value // 2),
+    )
+    return output
+
+
+def upfirdn2d_native(tensor, kernel=None, up=1, down=1, pad=(0, 0)):
+    """upfirdn2d native"""
+    up_x = up_y = up
+    down_x = down_y = down
+    pad_x0 = pad_y0 = pad[0]
+    pad_x1 = pad_y1 = pad[1]
+
+    _, channel, in_h, in_w = tensor.shape
+    tensor = tensor.reshape(-1, in_h, in_w, 1)
+
+    _, in_h, in_w, minor = tensor.shape
+    kernel_h, kernel_w = kernel.shape
+
+    out = tensor.view(-1, in_h, 1, in_w, 1, minor)
+    out = ops.pad(out, [0, 0, 0, up_x - 1, 0, 0, 0, up_y - 1])
+    out = out.view(-1, in_h * up_y, in_w * up_x, minor)
+
+    out = ops.pad(
+        out, [0, 0, max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)]
+    )
+    out = out[
+        :,
+        max(-pad_y0, 0) : out.shape[1] - max(-pad_y1, 0),
+        max(-pad_x0, 0) : out.shape[2] - max(-pad_x1, 0),
+        :,
+    ]
+
+    out = out.permute(0, 3, 1, 2)
+    out = out.reshape(
+        [-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1]
+    )
+    w = ops.flip(kernel, [0, 1]).view(1, 1, kernel_h,
kernel_w) + out = ops.conv2d(out, w) + out = out.reshape( + -1, + minor, + in_h * up_y + pad_y0 + pad_y1 - kernel_h + 1, + in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1, + ) + out = out.permute(0, 2, 3, 1) + out = out[:, ::down_y, ::down_x, :] + + out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h) // down_y + 1 + out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1 + + return out.view(-1, channel, out_h, out_w) diff --git a/MindEarth/applications/nowcasting/PreDiff/src/vae/unet_2d_blocks.py b/MindEarth/applications/nowcasting/PreDiff/src/vae/unet_2d_blocks.py new file mode 100644 index 0000000000000000000000000000000000000000..4f98b3d5ecc6e9841f7e63b66cdb0ee5d8fae01e --- /dev/null +++ b/MindEarth/applications/nowcasting/PreDiff/src/vae/unet_2d_blocks.py @@ -0,0 +1,508 @@ +# Copyright 2023 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"UNetMidBlock2D" +import math +from typing import Optional + +import mindspore as ms +from mindspore import nn, ops + +from .resnet import Downsample2D, ResnetBlock2D, Upsample2D + + +def get_down_block( + down_block_type, + num_layers, + in_channels, + out_channels, + temb_channels, + add_downsample, + resnet_eps, + resnet_act_fn, + attn_num_head_channels, + resnet_groups=None, + cross_attention_dim=None, + downsample_padding=None, + dual_cross_attention=False, + use_linear_projection=False, + only_cross_attention=False, + upcast_attention=False, + resnet_time_scale_shift="default", +): + """set down_block""" + down_block_type = ( + down_block_type[7:] + if down_block_type.startswith("UNetRes") + else down_block_type + ) + if down_block_type == "DownEncoderBlock2D": + return DownEncoderBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + add_downsample=add_downsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + downsample_padding=downsample_padding, + resnet_time_scale_shift=resnet_time_scale_shift, + ) + raise ValueError(f"{down_block_type} does not exist.") + + +def get_up_block( + up_block_type, + num_layers, + in_channels, + out_channels, + prev_output_channel, + temb_channels, + add_upsample, + resnet_eps, + resnet_act_fn, + attn_num_head_channels, + resnet_groups=None, + cross_attention_dim=None, + dual_cross_attention=False, + use_linear_projection=False, + only_cross_attention=False, + upcast_attention=False, + resnet_time_scale_shift="default", +): + """set up_block""" + up_block_type = ( + up_block_type[7:] if up_block_type.startswith("UNetRes") else up_block_type + ) + if up_block_type == "UpDecoderBlock2D": + return UpDecoderBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + add_upsample=add_upsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + resnet_time_scale_shift=resnet_time_scale_shift, + ) + raise ValueError(f"{up_block_type} does not exist.") + + +class 
AttentionBlock(nn.Cell):
+    """
+    An attention block that allows spatial positions to attend to each other, adapted to
+    the N-d case. Uses three q, k, v linear layers to compute attention.
+
+    Args:
+        channels (int): The number of channels in the input and output.
+        num_head_channels (int, optional): The number of channels in each attention head. If None, uses a single head.
+        norm_num_groups (int, optional): Number of groups for group normalization (default: 32).
+        rescale_output_factor (float, optional): Factor to rescale the output (default: 1.0).
+        eps (float, optional): Epsilon value for group normalization (default: 1e-5).
+
+    Attributes:
+        num_heads (int): Calculated number of attention heads based on `num_head_channels`.
+        group_norm (nn.GroupNorm): Group normalization layer.
+        query/key/value (nn.Dense): Linear layers for query, key, and value projections.
+        proj_attn (nn.Dense): Final projection layer after attention computation.
+    """
+
+    def __init__(
+        self,
+        channels: int,
+        num_head_channels: Optional[int] = None,
+        norm_num_groups: int = 32,
+        rescale_output_factor: float = 1.0,
+        eps: float = 1e-5,
+    ):
+        super().__init__()
+        self.channels = channels
+
+        self.num_heads = (
+            channels // num_head_channels if num_head_channels is not None else 1
+        )
+        self.num_head_size = num_head_channels
+        self.group_norm = nn.GroupNorm(
+            num_channels=channels, num_groups=norm_num_groups, eps=eps, affine=True
+        )
+
+        # define q,k,v as linear layers
+        self.query = nn.Dense(channels, channels)
+        self.key = nn.Dense(channels, channels)
+        self.value = nn.Dense(channels, channels)
+
+        self.rescale_output_factor = rescale_output_factor
+        self.proj_attn = nn.Dense(channels, channels)
+
+        self._use_memory_efficient_attention_xformers = False
+        self._attention_op = None
+        self.softmax_op = ops.Softmax(axis=-1)
+
+    def reshape_heads_to_batch_dim(self, tensor):
+        """
+        Reshape tensor to split attention heads into the batch dimension for efficient computation.
+        """
+        batch_size, seq_in, dim = tensor.shape
+        head_size = self.num_heads
+        tensor = tensor.reshape(batch_size, seq_in, head_size, dim // head_size)
+        tensor = tensor.permute(0, 2, 1, 3).reshape(
+            batch_size * head_size, seq_in, dim // head_size
+        )
+        return tensor
+
+    def reshape_batch_dim_to_heads(self, tensor):
+        """
+        Reverse reshape_heads_to_batch_dim to merge the batch dimension back into heads.
+ """ + batch_size, seq_in, dim = tensor.shape + head_size = self.num_heads + tensor = tensor.reshape(batch_size // head_size, head_size, seq_in, dim) + tensor = tensor.permute(0, 2, 1, 3).reshape( + batch_size // head_size, seq_in, dim * head_size + ) + return tensor + + def construct(self, hidden_states): + """Compute multi-head self-attention.""" + residual = hidden_states + batch, channel, height, width = hidden_states.shape + hidden_states = self.group_norm(hidden_states) + hidden_states = hidden_states.view(batch, channel, height * width).swapaxes( + 1, 2 + ) + query_proj = self.query(hidden_states) + key_proj = self.key(hidden_states) + value_proj = self.value(hidden_states) + + scale = 1 / math.sqrt(self.channels / self.num_heads) + + query_proj = self.reshape_heads_to_batch_dim(query_proj) + key_proj = self.reshape_heads_to_batch_dim(key_proj) + value_proj = self.reshape_heads_to_batch_dim(value_proj) + + shape = (query_proj.shape[0], query_proj.shape[1], key_proj.shape[1]) + uninitialized_tensor = ms.numpy.empty(shape=shape, dtype=query_proj.dtype) + attention_scores = ops.baddbmm( + uninitialized_tensor, + query_proj, + key_proj.swapaxes(-1, -2), + beta=0, + alpha=scale, + ) + attention_probs = self.softmax_op(attention_scores.astype(ms.float32)).type( + attention_scores.dtype + ) + + hidden_states = ops.bmm(attention_probs, value_proj) + hidden_states = self.reshape_batch_dim_to_heads(hidden_states) + hidden_states = self.proj_attn(hidden_states) + + hidden_states = hidden_states.swapaxes(-1, -2).reshape( + batch, channel, height, width + ) + hidden_states = (hidden_states + residual) / self.rescale_output_factor + return hidden_states + + +class UNetMidBlock2D(nn.Cell): + """ + UNet middle block for 2D architectures. + """ + def __init__( + self, + in_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + add_attention: bool = True, + attn_num_head_channels=1, + output_scale_factor=1.0, + ): + """ + UNet middle block for 2D architectures. Contains residual blocks and optional attention layers. + + Args: + in_channels (int): Number of input channels. + temb_channels (int): Number of time embedding channels. + dropout (float): Dropout probability (default: 0.0). + num_layers (int): Number of residual blocks (default: 1). + resnet_eps (float): Epsilon for ResNet normalization (default: 1e-6). + resnet_time_scale_shift (str): Time scale shift method for ResNet ("default" or "scale_shift"). + resnet_act_fn (str): Activation function for ResNet layers (default: "swish"). + resnet_groups (int): Number of groups for group normalization in ResNet. + resnet_pre_norm (bool): Whether to use pre-normalization in ResNet. + add_attention (bool): Whether to include attention blocks (default: True). + attn_num_head_channels (int): Number of channels per attention head. + output_scale_factor (float): Scaling factor for output (default: 1.0). + + Attributes: + resnets (nn.CellList): List of ResNet blocks. + attentions (nn.CellList): List of attention blocks (or None if disabled). 
+ """ + super().__init__() + resnet_groups = ( + resnet_groups if resnet_groups is not None else min(in_channels // 4, 32) + ) + self.add_attention = add_attention + self.num_layers = num_layers + # there is always at least one resnet + resnets = [ + ResnetBlock2D( + in_channels=in_channels, + out_channels=in_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ] + attentions = [] + + for _ in range(self.num_layers): + if self.add_attention: + attentions.append( + AttentionBlock( + in_channels, + num_head_channels=attn_num_head_channels, + rescale_output_factor=output_scale_factor, + eps=resnet_eps, + norm_num_groups=resnet_groups, + ) + ) + else: + attentions.append(None) + + resnets.append( + ResnetBlock2D( + in_channels=in_channels, + out_channels=in_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + + self.attentions = nn.CellList(attentions) + self.resnets = nn.CellList(resnets) + + def construct(self, hidden_states, temb=None): + """ + Forward pass through the middle block. + + Args: + hidden_states (Tensor): Input tensor. + temb (Tensor, optional): Time embedding tensor. + + Returns: + Tensor: Output tensor after processing through all blocks. + """ + + hidden_states = self.resnets[0](hidden_states, temb) + for attn, resnet in zip(self.attentions, self.resnets[1:]): + if attn is not None: + hidden_states = attn(hidden_states) + hidden_states = resnet(hidden_states, temb) + + return hidden_states + + +class DownEncoderBlock2D(nn.Cell): + """ + Downsample block for encoder part of UNet. + """ + def __init__( + self, + in_channels: int, + out_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + output_scale_factor=1.0, + add_downsample=True, + downsample_padding=1, + ): + """ + Downsample block for encoder part of UNet. Contains residual blocks and optional downsampling. + + Args: + in_channels (int): Number of input channels. + out_channels (int): Number of output channels. + dropout (float): Dropout probability (default: 0.0). + num_layers (int): Number of residual blocks (default: 1). + resnet_eps (float): Epsilon for ResNet normalization (default: 1e-6). + resnet_time_scale_shift (str): Time scale shift method for ResNet ("default" or "scale_shift"). + resnet_act_fn (str): Activation function for ResNet layers (default: "swish"). + resnet_groups (int): Number of groups for group normalization in ResNet. + resnet_pre_norm (bool): Whether to use pre-normalization in ResNet. + output_scale_factor (float): Scaling factor for output (default: 1.0). + add_downsample (bool): Whether to include downsampling layer (default: True). + downsample_padding (int): Padding for downsampling convolution (default: 1). + + Attributes: + resnets (nn.CellList): List of ResNet blocks. + downsamplers (nn.CellList or None): Downsampling layer if enabled. 
+ """ + super().__init__() + resnets = [] + + for i in range(num_layers): + in_channels = in_channels if i == 0 else out_channels + resnets.append( + ResnetBlock2D( + in_channels=in_channels, + out_channels=out_channels, + temb_channels=None, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + + self.resnets = nn.CellList(resnets) + + if add_downsample: + self.downsamplers = nn.CellList( + [ + Downsample2D( + out_channels, + use_conv=True, + out_channels=out_channels, + padding=downsample_padding, + name="op", + ) + ] + ) + else: + self.downsamplers = None + + def construct(self, hidden_states): + """ + Forward pass through the downsample block. + + Args: + hidden_states (Tensor): Input tensor. + + Returns: + Tensor: Output tensor after residual blocks and optional downsampling. + """ + for resnet in self.resnets: + hidden_states = resnet(hidden_states, temb=None) + + if self.downsamplers is not None: + for downsampler in self.downsamplers: + hidden_states = downsampler(hidden_states) + + return hidden_states + + +class UpDecoderBlock2D(nn.Cell): + """ + Upsample block for decoder part of UNet. + """ + def __init__( + self, + in_channels: int, + out_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + output_scale_factor=1.0, + add_upsample=True, + ): + """ + Upsample block for decoder part of UNet. Contains residual blocks and optional upsampling. + + Args: + in_channels (int): Number of input channels. + out_channels (int): Number of output channels. + dropout (float): Dropout probability (default: 0.0). + num_layers (int): Number of residual blocks (default: 1). + resnet_eps (float): Epsilon for ResNet normalization (default: 1e-6). + resnet_time_scale_shift (str): Time scale shift method for ResNet ("default" or "scale_shift"). + resnet_act_fn (str): Activation function for ResNet layers (default: "swish"). + resnet_groups (int): Number of groups for group normalization in ResNet. + resnet_pre_norm (bool): Whether to use pre-normalization in ResNet. + output_scale_factor (float): Scaling factor for output (default: 1.0). + add_upsample (bool): Whether to include upsampling layer (default: True). + + Attributes: + resnets (nn.CellList): List of ResNet blocks. + upsamplers (nn.CellList or None): Upsampling layer if enabled. 
+ """ + super().__init__() + resnets = [] + + for i in range(num_layers): + input_channels = in_channels if i == 0 else out_channels + + resnets.append( + ResnetBlock2D( + in_channels=input_channels, + out_channels=out_channels, + temb_channels=None, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + + self.resnets = nn.CellList(resnets) + + if add_upsample: + self.upsamplers = nn.CellList( + [Upsample2D(out_channels, use_conv=True, out_channels=out_channels)] + ) + else: + self.upsamplers = None + + def construct(self, hidden_states): + """Forward pass through the upsample block.""" + for resnet in self.resnets: + hidden_states = resnet(hidden_states, temb=None) + + if self.upsamplers is not None: + for upsampler in self.upsamplers: + hidden_states = upsampler(hidden_states) + + return hidden_states diff --git a/MindEarth/applications/nowcasting/PreDiff/src/visual.py b/MindEarth/applications/nowcasting/PreDiff/src/visual.py new file mode 100644 index 0000000000000000000000000000000000000000..343f3e17edfa1c455ba6e2cf4ee90f12a0adc8ef --- /dev/null +++ b/MindEarth/applications/nowcasting/PreDiff/src/visual.py @@ -0,0 +1,203 @@ +# Copyright 2023 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"visualization" +import math +from copy import deepcopy +from typing import Optional, Sequence, Union, Dict +from matplotlib import pyplot as plt +from matplotlib.colors import ListedColormap, BoundaryNorm +from matplotlib.font_manager import FontProperties +from matplotlib.patches import Patch +import numpy as np + + +VIL_COLORS = [ + [0, 0, 0], + [0.30196078431372547, 0.30196078431372547, 0.30196078431372547], + [0.1568627450980392, 0.7450980392156863, 0.1568627450980392], + [0.09803921568627451, 0.5882352941176471, 0.09803921568627451], + [0.0392156862745098, 0.4117647058823529, 0.0392156862745098], + [0.0392156862745098, 0.29411764705882354, 0.0392156862745098], + [0.9607843137254902, 0.9607843137254902, 0.0], + [0.9294117647058824, 0.6745098039215687, 0.0], + [0.9411764705882353, 0.43137254901960786, 0.0], + [0.6274509803921569, 0.0, 0.0], + [0.9058823529411765, 0.0, 1.0], +] + +VIL_LEVELS = [0.0, 16.0, 31.0, 59.0, 74.0, 100.0, 133.0, 160.0, 181.0, 219.0, 255.0] + + +def vil_cmap(): + """ + Generate a ListedColormap and normalization for VIL (Vertically Integrated Liquid) visualization. + + This function creates a colormap with specific color levels for VIL data visualization. It sets under/over colors + for values outside the defined levels and handles invalid (NaN) values. + + Returns: + tuple: A tuple containing: + - cmap (ListedColormap): Colormap object with defined colors. + - norm (BoundaryNorm): Normalization object based on VIL levels. + - vmin (None): Minimum value for colormap (set to None). 
+ - vmax (None): Maximum value for colormap (set to None). + """ + cols = deepcopy(VIL_COLORS) + lev = deepcopy(VIL_LEVELS) + nil = cols.pop(0) + under = cols[0] + over = cols[-1] + cmap = ListedColormap(cols) + cmap.set_bad(nil) + cmap.set_under(under) + cmap.set_over(over) + norm = BoundaryNorm(lev, cmap.N) + vmin, vmax = None, None + return cmap, norm, vmin, vmax + + +def vis_sevir_seq( + save_path, + seq: Union[np.ndarray, Sequence[np.ndarray]], + label: Union[str, Sequence[str]] = "pred", + norm: Optional[Dict[str, float]] = None, + interval_real_time: float = 10.0, + plot_stride=2, + label_rotation=0, + label_offset=(-0.06, 0.4), + label_avg_int=False, + fs=10, + max_cols=10, +): + """Visualize SEVIR sequence data as a grid of images with colormap and annotations. + Args: + save_path (str): Path to save the output visualization figure. + seq (Union[np.ndarray, Sequence[np.ndarray]]): Input data sequence(s) to visualize. + Can be a single array or list of arrays. + label (Union[str, Sequence[str]], optional): Labels for each sequence. Defaults to "pred". + norm (Optional[Dict[str, float]], optional): Normalization parameters (scale/shift). + Defaults to {"scale": 255, "shift": 0}. + interval_real_time (float, optional): Time interval between frames in real time. Defaults to 10.0. + plot_stride (int, optional): Stride for subsampling frames. Defaults to 2. + label_rotation (int, optional): Rotation angle for y-axis labels. Defaults to 0. + label_offset (tuple, optional): Offset for y-axis label position. Defaults to (-0.06, 0.4). + label_avg_int (bool, optional): Append average intensity to labels. Defaults to False. + fs (int, optional): Font size for text elements. Defaults to 10. + max_cols (int, optional): Maximum number of columns per row. Defaults to 10. + + Raises: + NotImplementedError: If input sequence type is not supported. + + Returns: + None: Saves visualization to disk and closes the figure. + """ + def cmap_dict(): + return { + "cmap": vil_cmap()[0], + "norm": vil_cmap()[1], + "vmin": vil_cmap()[2], + "vmax": vil_cmap()[3], + } + + fontproperties = FontProperties() + fontproperties.set_family("serif") + fontproperties.set_size(fs) + + if isinstance(seq, Sequence): + seq_list = [ele.astype(np.float32) for ele in seq] + assert isinstance(label, Sequence) and len(label) == len(seq) + label_list = label + elif isinstance(seq, np.ndarray): + seq_list = [ + seq.astype(np.float32), + ] + assert isinstance(label, str) + label_list = [ + label, + ] + else: + raise NotImplementedError + if label_avg_int: + label_list = [ + f"{ele1}\nAvgInt = {np.mean(ele2): .3f}" + for ele1, ele2 in zip(label_list, seq_list) + ] + seq_list = [ele[::plot_stride, ...] 
for ele in seq_list] + seq_in_list = [len(ele) for ele in seq_list] + max_len = max(seq_in_list) + max_len = min(max_len, max_cols) + seq_list_wrap = [] + label_list_wrap = [] + seq_in_list_wrap = [] + for i, (processed_seq, processed_label, seq_in) in enumerate(zip(seq_list, label_list, seq_in_list)): + num_row = math.ceil(seq_in / max_len) + for j in range(num_row): + slice_end = min(seq_in, (j + 1) * max_len) + seq_list_wrap.append(processed_seq[j * max_len : slice_end]) + if j == 0: + label_list_wrap.append(processed_label) + else: + label_list_wrap.append("") + seq_in_list_wrap.append(min(seq_in - j * max_len, max_len)) + + if norm is None: + norm = {"scale": 255, "shift": 0} + nrows = len(seq_list_wrap) + fig, ax = plt.subplots(nrows=nrows, ncols=max_len, figsize=(3 * max_len, 3 * nrows)) + + for i, (processed_seq, processed_label, seq_in) in enumerate( + zip(seq_list_wrap, label_list_wrap, seq_in_list_wrap) + ): + ax[i][0].set_ylabel( + ylabel=processed_label, fontproperties=fontproperties, rotation=label_rotation + ) + ax[i][0].yaxis.set_label_coords(label_offset[0], label_offset[1]) + for j in range(0, max_len): + if j < seq_in: + x = processed_seq[j] * norm["scale"] + norm["shift"] + ax[i][j].imshow(x, **cmap_dict()) + if i == len(seq_list) - 1 and i > 0: + ax[-1][j].set_title( + f"Min {int(interval_real_time * (j + 1) * plot_stride)}", + y=-0.25, + fontproperties=fontproperties, + ) + else: + ax[i][j].axis("off") + + for i in range(len(ax)): + for j in range(len(ax[i])): + ax[i][j].xaxis.set_ticks([]) + ax[i][j].yaxis.set_ticks([]) + + num_thresh_legend = len(VIL_LEVELS) - 1 + legend_elements = [ + Patch( + facecolor=VIL_COLORS[i], + label=f"{int(VIL_LEVELS[i - 1])}-{int(VIL_LEVELS[i])}", + ) + for i in range(1, num_thresh_legend + 1) + ] + ax[0][0].legend( + handles=legend_elements, + loc="center left", + bbox_to_anchor=(-1.2, -0.0), + borderaxespad=0, + frameon=False, + fontsize="10", + ) + plt.subplots_adjust(hspace=0.05, wspace=0.05) + plt.savefig(save_path) + plt.close(fig) diff --git a/MindEarth/applications/nowcasting/dgmr/src/utils.py b/MindEarth/applications/nowcasting/dgmr/src/utils.py index d65e1102cc6be4d3e13842830c65db63329858ff..7df841facdc75c48188beda91e0973df9275ec9e 100644 --- a/MindEarth/applications/nowcasting/dgmr/src/utils.py +++ b/MindEarth/applications/nowcasting/dgmr/src/utils.py @@ -17,11 +17,6 @@ import os import numpy as np import matplotlib.pyplot as plt - -import mindspore.nn.probability.distribution as msd -from mindspore import ops - - from mindspore import context from mindspore.train.serialization import load_checkpoint, load_param_into_net import mindspore.communication.management as D @@ -50,12 +45,6 @@ def init_data_parallel(use_ascend): def init_model(config): r"""init model.""" model_params = config["model"] - - net = msd.Normal(0.0, 1.0, seed=42) - z = net.sample((8, 8, 8, 1)) - z = ops.transpose(z, (3, 0, 1, 2)) - print(z.shape) - g_model = DgmrGenerator( forecast_steps=model_params["forecast_steps"], in_channels=model_params["in_channels"], @@ -64,7 +53,6 @@ def init_model(config): latent_channels=model_params["latent_channels"], context_channels=model_params["context_channels"], generation_steps=model_params["generation_steps"], - z=z ) d_model = DgmrDiscriminator( in_channels=model_params["in_channels"], diff --git a/MindEarth/mindearth/cell/demnet/demnet.py b/MindEarth/mindearth/cell/demnet/demnet.py index 98bd1c6bb2bc99c487175fe1b2ce1af0ede69535..b1981eb00048d19be27d35ccb5dcd12f91b540ee 100644 --- 
a/MindEarth/mindearth/cell/demnet/demnet.py +++ b/MindEarth/mindearth/cell/demnet/demnet.py @@ -128,6 +128,8 @@ class DEMNet(nn.Cell): self.conv_up = nn.Conv2d(out_channels, out_channels, kernel_size, pad_mode='same') self.conv_out = nn.Conv2d(out_channels, in_channels, kernel_size, pad_mode='same') self.body = self.make_layer(ResBlock, num_blocks) + self.resizebilinear = ms.ops.ResizeBilinearV2() + def make_layer(self, block, layers): res_block = [] @@ -141,6 +143,6 @@ class DEMNet(nn.Cell): out = self.conv2(out) out += x out = self.conv_up(out) - out = ms.nn.ResizeBilinear()(out, scale_factor=self.scale) + out = self.resizebilinear(out, (out.shape[2]*self.scale, out.shape[3]*self.scale)) out = self.conv_out(out) return out diff --git a/MindEarth/mindearth/cell/dgmr/dgmr.py b/MindEarth/mindearth/cell/dgmr/dgmr.py index 2adf22623a97cdc40496635e578f3e9039943ba3..9a327fd10ea41c7bcdadcb4f3b4271260bbe7c65 100644 --- a/MindEarth/mindearth/cell/dgmr/dgmr.py +++ b/MindEarth/mindearth/cell/dgmr/dgmr.py @@ -20,7 +20,7 @@ import numpy as np import mindspore as ms from mindspore import set_seed import mindspore.nn.probability.distribution as msd -from mindspore import nn, ops, Tensor, Parameter +from mindspore import nn, ops, Tensor, Parameter, mint from mindearth.cell.utils import SpectralNorm, PixelUnshuffle, PixelShuffle @@ -35,7 +35,7 @@ def get_conv_layer(conv_type="standard"): elif conv_type == "coord": conv_layer = CoordConv elif conv_type == "3d": - conv_layer = nn.Conv3d + conv_layer = mint.nn.Conv3d else: raise ValueError(f"{conv_type} is not a recognized Conv method") return conv_layer @@ -311,43 +311,74 @@ class DBlock(nn.Cell): conv2d = get_conv_layer(conv_type) if conv_type == "3d": # 3D Average pooling - self.pooling = ops.AvgPool3D(kernel_size=2, strides=2) + self.pooling = ops.MaxPool3D(kernel_size=2, strides=2) + self.conv_1x1 = conv2d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=1, + bias=True + ) + if use_spectral_norm: + self.conv_1x1 = SpectralNorm( + self.conv_1x1 + ) + + self.first_conv_3x3 = conv2d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=3, + padding=1, + bias=True + ) + if use_spectral_norm: + self.first_conv_3x3 = SpectralNorm( + self.first_conv_3x3 + ) + + self.last_conv_3x3 = conv2d( + in_channels=out_channels, + out_channels=out_channels, + kernel_size=3, + padding=1, + stride=1, + bias=True + ) else: self.pooling = nn.AvgPool2d(kernel_size=2, stride=2) - self.conv_1x1 = conv2d( - in_channels=in_channels, - out_channels=out_channels, - kernel_size=1, - has_bias=True - ) - if use_spectral_norm: - self.conv_1x1 = SpectralNorm( - self.conv_1x1 + self.conv_1x1 = conv2d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=1, + has_bias=True ) + if use_spectral_norm: + self.conv_1x1 = SpectralNorm( + self.conv_1x1 + ) - self.first_conv_3x3 = conv2d( - in_channels=in_channels, - out_channels=out_channels, - kernel_size=3, - padding=1, - pad_mode="pad", - has_bias=True - ) - if use_spectral_norm: - self.first_conv_3x3 = SpectralNorm( - self.first_conv_3x3 + self.first_conv_3x3 = conv2d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=3, + padding=1, + pad_mode="pad", + has_bias=True ) + if use_spectral_norm: + self.first_conv_3x3 = SpectralNorm( + self.first_conv_3x3 + ) - self.last_conv_3x3 = conv2d( - in_channels=out_channels, - out_channels=out_channels, - kernel_size=3, - padding=1, - pad_mode="pad", - stride=1, - has_bias=True - ) + self.last_conv_3x3 = conv2d( + 
in_channels=out_channels, + out_channels=out_channels, + kernel_size=3, + padding=1, + pad_mode="pad", + stride=1, + has_bias=True + ) self.relu = nn.ReLU() @@ -723,7 +754,7 @@ class UpsampleGBlock(nn.Cell): eps=spectral_normalized_eps, ) - self.upsample = nn.ResizeBilinear() + self.upsample = ops.ResizeBilinearV2(align_corners=True) # Upsample 2D conv self.first_conv_3x3 = conv2d( in_channels=in_channels, @@ -751,12 +782,14 @@ class UpsampleGBlock(nn.Cell): def construct(self, x): """UpsampleGBlock forward function""" # Spectrally normalized 1x1 convolution - sc = self.upsample(x, scale_factor=2, align_corners=True) + shape = x.shape + sc = self.upsample(x, (2*shape[2], 2*shape[3])) sc = self.conv_1x1(sc) x2 = self.bn1(x) x2 = self.relu(x2) # Upsample - x2 = self.upsample(x2, scale_factor=2, align_corners=True) + shape = x2.shape + x2 = self.upsample(x2, (2*shape[2], 2*shape[3])) x2 = self.first_conv_3x3(x2) # Make sure size is doubled x2 = self.bn2(x2) x2 = self.relu(x2) @@ -935,7 +968,7 @@ class TemporalDiscriminator(nn.Cell): conv_type="standard", use_spectral_norm=True): super().__init__() - self.downsample = ops.AvgPool3D(kernel_size=(1, 2, 2), strides=(1, 2, 2)) + self.downsample = ops.MaxPool3D(kernel_size=(1, 2, 2), strides=(1, 2, 2)) self.space2depth = PixelUnshuffle(downscale_factor=2) hidden_channels = 48 self.d1 = DBlock( @@ -1025,7 +1058,7 @@ class SpatialDiscriminator(nn.Cell): super().__init__() self.num_timesteps = num_timesteps self.mean_pool = nn.AvgPool2d(kernel_size=2, stride=2) - self.downsample = ops.AvgPool3D(kernel_size=(1, 2, 2), strides=(1, 2, 2)) + self.downsample = ops.MaxPool3D(kernel_size=(1, 2, 2), strides=(1, 2, 2)) self.space2depth = PixelUnshuffle(downscale_factor=2) hidden_channels = 24 self.d1 = DBlock( diff --git a/MindEarth/mindearth/cell/dgmr/dgmrnet.py b/MindEarth/mindearth/cell/dgmr/dgmrnet.py index 8985581f41a38205b0db1b46efc3e97f000234b6..411dd57fa3ceb416e28e4cf0d0b675d3a6f358c2 100644 --- a/MindEarth/mindearth/cell/dgmr/dgmrnet.py +++ b/MindEarth/mindearth/cell/dgmr/dgmrnet.py @@ -133,7 +133,6 @@ class DgmrGenerator(Cell): """ def __init__( self, - z, forecast_steps=18, in_channels=1, out_channels=256, @@ -147,7 +146,6 @@ class DgmrGenerator(Cell): self.context_channels = context_channels self.in_channels = in_channels self.generation_steps = generation_steps - self.z = z self.conditioning_stack = ContextConditioningStack( in_channels=in_channels, conv_type=conv_type, @@ -167,6 +165,6 @@ class DgmrGenerator(Cell): def construct(self, x): """Dgmr generator forward function.""" conditioning_states = self.conditioning_stack(x) - latent_dim = self.latent_stack(x, self.z) + latent_dim = self.latent_stack(x) output = self.sampler(conditioning_states, latent_dim) return output diff --git a/tests/st/mindearth/module/test_forecast.py b/tests/st/mindearth/module/test_forecast.py index 9e9cd19967474eb95608f2433f32a997ff5fedca..bac0dedd7972ed8d14f6d6126a678dcee185eeb3 100644 --- a/tests/st/mindearth/module/test_forecast.py +++ b/tests/st/mindearth/module/test_forecast.py @@ -68,7 +68,7 @@ class MyInference(WeatherForecast): def forecast(self, inputs): pred_lst = [] - for _ in range(self.t_out_test): + for _ in range(self.t_out): pred = self.model(inputs) pred_lst.append(pred) inputs = pred diff --git a/tests/st/mindearth/module/test_pretrain.py b/tests/st/mindearth/module/test_pretrain.py index c3619bd48fc278782fbd0993b4f5a15de87f8f5d..8c2136613d6edfe67ade41d2aa580facc4f3b512 100644 --- a/tests/st/mindearth/module/test_pretrain.py +++ 
b/tests/st/mindearth/module/test_pretrain.py @@ -73,6 +73,9 @@ class MyTrainer(Trainer): dataset = ds.GeneratorDataset(source=MyIterable(), column_names=["inputs", "labels"]) return dataset, dataset + def get_data_generator(self): + train_dataset_generator, valid_dataset_generator = self.get_dataset() + return train_dataset_generator, valid_dataset_generator @pytest.mark.level0 @platform_arm_ascend910b_training diff --git a/tests/st/mindearth/test_config.yaml b/tests/st/mindearth/test_config.yaml index 9073c5b66dfe6b4de21475514537ad5060e7ff8b..e752d0b747c9427ca5c2700eb0a3953977b96f77 100644 --- a/tests/st/mindearth/test_config.yaml +++ b/tests/st/mindearth/test_config.yaml @@ -46,6 +46,7 @@ summary: save_checkpoint_steps: 10 keep_checkpoint_max: 10 plt_key_info: False + save_checkpoint_epochs: 10 key_info_timestep: [6,] train: name: "oop"