diff --git a/CombinedMethod/.idea/.gitignore b/CombinedMethod/.idea/.gitignore
new file mode 100644
index 0000000000000000000000000000000000000000..26d33521af10bcc7fd8cea344038eaaeb78d0ef5
--- /dev/null
+++ b/CombinedMethod/.idea/.gitignore
@@ -0,0 +1,3 @@
+# Default ignored files
+/shelf/
+/workspace.xml
diff --git a/CombinedMethod/.idea/.name b/CombinedMethod/.idea/.name
new file mode 100644
index 0000000000000000000000000000000000000000..54753c0a647e732a0e77fa6200c2327e025730ca
--- /dev/null
+++ b/CombinedMethod/.idea/.name
@@ -0,0 +1 @@
+CombinedMethods-Sink.py
\ No newline at end of file
diff --git a/CombinedMethod/.idea/deployment.xml b/CombinedMethod/.idea/deployment.xml
new file mode 100644
index 0000000000000000000000000000000000000000..b388b64d771254ae75b62823e7120fcd4345181a
--- /dev/null
+++ b/CombinedMethod/.idea/deployment.xml
@@ -0,0 +1,24 @@
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
diff --git a/CombinedMethod/.idea/inspectionProfiles/Project_Default.xml b/CombinedMethod/.idea/inspectionProfiles/Project_Default.xml
new file mode 100644
index 0000000000000000000000000000000000000000..feb03d190466500fb3efaff5e228ad282dfc79a7
--- /dev/null
+++ b/CombinedMethod/.idea/inspectionProfiles/Project_Default.xml
@@ -0,0 +1,12 @@
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/CombinedMethod/.idea/inspectionProfiles/profiles_settings.xml b/CombinedMethod/.idea/inspectionProfiles/profiles_settings.xml
new file mode 100644
index 0000000000000000000000000000000000000000..105ce2da2d6447d11dfe32bfb846c3d5b199fc99
--- /dev/null
+++ b/CombinedMethod/.idea/inspectionProfiles/profiles_settings.xml
@@ -0,0 +1,6 @@
+
+
+
+
+
+
\ No newline at end of file
diff --git a/CombinedMethod/.idea/misc.xml b/CombinedMethod/.idea/misc.xml
new file mode 100644
index 0000000000000000000000000000000000000000..f5df5dc2714b8e274062ac0a16cfefc7f7c87b77
--- /dev/null
+++ b/CombinedMethod/.idea/misc.xml
@@ -0,0 +1,10 @@
+
+
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/CombinedMethod/.idea/modules.xml b/CombinedMethod/.idea/modules.xml
new file mode 100644
index 0000000000000000000000000000000000000000..57f9803524091915161c65b208e63b0d9705bd05
--- /dev/null
+++ b/CombinedMethod/.idea/modules.xml
@@ -0,0 +1,8 @@
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/CombinedMethod/.idea/new618.iml b/CombinedMethod/.idea/new618.iml
new file mode 100644
index 0000000000000000000000000000000000000000..49830d77e1ea2fabf43b8dc06adb580bab524ac3
--- /dev/null
+++ b/CombinedMethod/.idea/new618.iml
@@ -0,0 +1,12 @@
+
+
+
+
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/CombinedMethod/CombinedMethods-Sink.py b/CombinedMethod/CombinedMethods-Sink.py
new file mode 100644
index 0000000000000000000000000000000000000000..566144116306bc35b62cf04009ab94c929428846
--- /dev/null
+++ b/CombinedMethod/CombinedMethods-Sink.py
@@ -0,0 +1,160 @@
+# 项目初始化
+import os
+import time
+
+import numpy as np
+from mindspore import nn, Tensor, context, ops, jit, set_seed, data_sink, save_checkpoint
+from mindspore import dtype as mstype
+from mindflow.common import get_warmup_cosine_annealing_lr
+from mindflow.loss import RelativeRMSELoss
+from mindspore.nn import L1Loss
+from mindflow.utils import load_yaml_config, print_log
+
+from src.utils import Trainer, init_model, check_file_path, count_params, plot_image, plot_image_1
+from src.dataset import init_dataset
+from src.visual import plt_log
+
+set_seed(0)
+np.random.seed(0)
+
+context.set_context(mode=context.GRAPH_MODE,
+ save_graphs=False,
+ device_target="Ascend",
+ device_id=0)
+use_ascend = context.get_context("device_target") == "Ascend"
+print(use_ascend)
+
+# 配置训练参数
+config = load_yaml_config("./config/combined_methods.yaml")
+data_params = config["data"]
+model_params = config["model"]
+optimizer_params = config["optimizer"]
+summary_params = config["summary"]
+
+# 准备数据集
+train_dataset, test_dataset, means, stds = init_dataset(data_params)
+print('train_dataset', train_dataset)
+# print(train_dataset.create_tuple_iterator())
+
+# 模型构建
+if use_ascend:
+ from mindspore.amp import DynamicLossScaler, all_finite, auto_mixed_precision
+
+ loss_scaler = DynamicLossScaler(1024, 2, 100)
+ compute_dtype = mstype.float16
+ model = init_model("unet2d", data_params, model_params, compute_dtype=compute_dtype)
+ auto_mixed_precision(model, optimizer_params["amp_level"]["unet2d"])
+else:
+ context.set_context(enable_graph_kernel=True)
+ loss_scaler = None
+ compute_dtype = mstype.float32
+ model = init_model("unet2d", data_params, model_params, compute_dtype=compute_dtype)
+
+# 损失函数与优化器
+# loss_fn = RelativeRMSELoss()
+loss_fn = L1Loss()
+
+summary_dir = os.path.join(summary_params["summary_dir"], "Exp_datadriven", "unet2d")
+ckpt_dir = os.path.join(summary_dir, "ckpt_dir")
+check_file_path(ckpt_dir)
+check_file_path(os.path.join(ckpt_dir, 'img'))
+print_log('model parameter count:', count_params(model.trainable_params()))
+print_log(
+ f'learning rate: {optimizer_params["lr"]["unet2d"]}, T_in: {data_params["T_in"]}, T_out: {data_params["T_out"]}')
+steps_per_epoch = train_dataset.get_dataset_size()
+
+lr = get_warmup_cosine_annealing_lr(optimizer_params["lr"]["unet2d"], steps_per_epoch,
+ optimizer_params["epochs"], optimizer_params["warm_up_epochs"])
+optimizer = nn.AdamWeightDecay(model.trainable_params(),
+ learning_rate=Tensor(lr),
+ weight_decay=optimizer_params["weight_decay"])
+
+# 训练函数
+trainer = Trainer(model, data_params, loss_fn, means, stds)
+
+
+# 执行模型的前向传播并返回损失值
+def forward_fn(inputs, labels):
+ # loss, l_data, l_phy, _, _, _, _ = trainer.get_loss(inputs, labels)
+ loss, _, _, _, _, _, _ = trainer.get_loss(inputs, labels)
+ if use_ascend:
+ loss = loss_scaler.scale(loss)
+ return loss
+
+
+# 计算前向函数 forward_fn 的输出(即损失)相对于模型参数的梯度。
+grad_fn = ops.value_and_grad(forward_fn, None, optimizer.parameters, has_aux=False)
+
+
+# 用于通过反向传播算法优化模型以最小化损失函数。
+@jit
+def train_step(inputs, labels):
+ loss, grads = grad_fn(inputs, labels)
+ if use_ascend:
+ loss = loss_scaler.unscale(loss)
+ if all_finite(grads):
+ grads = loss_scaler.unscale(grads)
+ loss_new = ops.depend(loss, optimizer(grads))
+ return loss_new
+
+
+# get_loss 方法负责执行模型的前向传播,计算预测输出与真实标签之间的差异,并据此计算损失。
+def test_step(inputs, labels):
+ return trainer.get_loss(inputs, labels)
+
+
+# 获取训练数据集的大小,即训练数据集中的批次总数。这个值用于确定训练循环中的迭代次数。
+train_size = train_dataset.get_dataset_size()
+test_size = test_dataset.get_dataset_size()
+# data_sink 是 MindSpore 中的一个功能,它可以并行地执行数据的加载和处理,以及模型的前向和后向传播。
+train_sink = data_sink(train_step, train_dataset, sink_size=1)
+test_sink = data_sink(test_step, test_dataset, sink_size=1)
+# test_interval 指定了在训练过程中每隔多少个epoch执行一次测试
+test_interval = summary_params["test_interval"]
+# save_ckpt_interval 指定了在训练过程中每隔多少个epoch保存一次模型的检查点(checkpoint)。
+save_ckpt_interval = summary_params["save_ckpt_interval"]
+
+# 模型训练
+for epoch in range(1, optimizer_params["epochs"] + 1):
+ time_beg = time.time()
+ train_l1 = 0.0
+ # 将模型设置为训练模式
+ model.set_train()
+ for step in range(1, train_size + 1):
+ # 返回当前批次的损失值
+ loss_train = train_sink()
+ train_l1 += loss_train.asnumpy()
+ train_loss = train_l1 / train_size
+ print_log(
+ f"epoch: {epoch}, step time: {(time.time() - time_beg) / steps_per_epoch:>7f}, loss: {train_loss:>7f}")
+
+ if epoch % test_interval == 0:
+ model.set_train(False)
+ test_l1 = 0.0
+ # 用于在测试集上评估模型的性能
+ for step in range(test_size):
+ loss_test, loss1, loss2, inputs, pred, labels, step_losses = test_sink()
+ test_l1 += loss_test.asnumpy()
+ test_loss = test_l1 / test_size
+ print_log(
+ f"epoch: {epoch}, step time: {(time.time() - time_beg) / steps_per_epoch:>7f}, loss: {test_loss:>7f}")
+
+ # 调用可视化函数来绘制图像,通常用于在训练或测试过程中观察模型的输入、预测输出和真实标签。
+ plot_image(inputs, 0)
+ plot_image_1(inputs, 0)
+ plot_image(pred, 0)
+ plot_image(labels, 0)
+
+ # 更新滑动窗口中的损失
+ if epoch >= trainer.window_size:
+ # 调用get_loss获取当前损失值
+ loss_total, loss1, loss2, _, _, _, _ = trainer.get_loss(inputs, labels)
+ # 更新滑动窗口中的损失
+ trainer.update_loss_lists(loss1, loss2)
+ # 调整权重
+ trainer.adjust_weights(epoch)
+
+ if epoch % save_ckpt_interval == 0:
+ save_checkpoint(model, ckpt_file_name=os.path.join(ckpt_dir, 'model_data.ckpt'))
+
+print("Training Finished!!")
diff --git a/CombinedMethod/config/combined_methods.yaml b/CombinedMethod/config/combined_methods.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..d5878fae2957e275fe5886960239dcb4a3d9030d
--- /dev/null
+++ b/CombinedMethod/config/combined_methods.yaml
@@ -0,0 +1,48 @@
+model:
+ # in_channels: 3
+ # out_channels: 3
+ in_channels: 2
+ out_channels: 1
+ resolution: 128
+ load_ckpt: False
+ unet2d:
+ kernel_size: 2
+ stride: 2
+ base_channels: 64
+ fno2d:
+ modes: 12
+ channels: 64
+ depths: 4
+ mlp_ratio: 4
+
+data:
+ root_dir: "./dataset/"
+ dataset_name: "S_0_0_10_5_10_5_5.npz"
+ data_size: 3600
+ train_size: 3000
+ means: [0.5803, 0.0401, 0.6564]
+ stds: [0.3523, 0.0885, 0.1459]
+ T_in: 8
+ T_out: 32
+ train_batch_size: 8
+ test_batch_size: 8
+
+optimizer:
+ epochs: 2000
+ save_epoch: 100
+ warm_up_epochs: 1
+ gamma: 0.2
+ weight_decay: 0.001
+ amp_level:
+ fno2d: "O2"
+ unet2d: "O1"
+ lr:
+ fno2d: 0.001
+ unet2d: 0.0001
+
+summary:
+ summary_dir: "./summary_dir/"
+ # pretrained_ckpt_dir: "path/to/ckpt"
+ pretrained_ckpt_dir: "./summary_dir/ckpt"
+ save_ckpt_interval: 10
+ test_interval: 10
\ No newline at end of file
diff --git a/CombinedMethod/dataset/S_0_0_10_10_10_0_0.npz b/CombinedMethod/dataset/S_0_0_10_10_10_0_0.npz
new file mode 100644
index 0000000000000000000000000000000000000000..88f7a2c0f92eb47e9786c30cbfd520b7e8162055
Binary files /dev/null and b/CombinedMethod/dataset/S_0_0_10_10_10_0_0.npz differ
diff --git a/CombinedMethod/dataset/S_0_0_10_10_10_0_10.npz b/CombinedMethod/dataset/S_0_0_10_10_10_0_10.npz
new file mode 100644
index 0000000000000000000000000000000000000000..a6481e9d59e3e20eba29d70096b116219b63df82
Binary files /dev/null and b/CombinedMethod/dataset/S_0_0_10_10_10_0_10.npz differ
diff --git a/CombinedMethod/dataset/S_0_0_10_10_10_0_5.npz b/CombinedMethod/dataset/S_0_0_10_10_10_0_5.npz
new file mode 100644
index 0000000000000000000000000000000000000000..233b74e85296e02b9182770a9df82d232b2c9fe0
Binary files /dev/null and b/CombinedMethod/dataset/S_0_0_10_10_10_0_5.npz differ
diff --git a/CombinedMethod/dataset/S_0_0_10_10_10_10_0.npz b/CombinedMethod/dataset/S_0_0_10_10_10_10_0.npz
new file mode 100644
index 0000000000000000000000000000000000000000..36a4c6fc9bb650abf4668924b45992d1fe81a42d
Binary files /dev/null and b/CombinedMethod/dataset/S_0_0_10_10_10_10_0.npz differ
diff --git a/CombinedMethod/dataset/S_0_0_10_10_10_10_10.npz b/CombinedMethod/dataset/S_0_0_10_10_10_10_10.npz
new file mode 100644
index 0000000000000000000000000000000000000000..6e8ca80559de4af103c9a8b5faac4a71b3e9c17f
Binary files /dev/null and b/CombinedMethod/dataset/S_0_0_10_10_10_10_10.npz differ
diff --git a/CombinedMethod/dataset/S_0_0_10_10_10_10_5.npz b/CombinedMethod/dataset/S_0_0_10_10_10_10_5.npz
new file mode 100644
index 0000000000000000000000000000000000000000..cc7b2d9606113bd32e4cb12f06c11ceeb79d214b
Binary files /dev/null and b/CombinedMethod/dataset/S_0_0_10_10_10_10_5.npz differ
diff --git a/CombinedMethod/dataset/S_0_0_10_10_10_5_0.npz b/CombinedMethod/dataset/S_0_0_10_10_10_5_0.npz
new file mode 100644
index 0000000000000000000000000000000000000000..38902ca4d1f137d3e34102216d97fdf757708ebb
Binary files /dev/null and b/CombinedMethod/dataset/S_0_0_10_10_10_5_0.npz differ
diff --git a/CombinedMethod/dataset/S_0_0_10_10_10_5_10.npz b/CombinedMethod/dataset/S_0_0_10_10_10_5_10.npz
new file mode 100644
index 0000000000000000000000000000000000000000..978a99c0840217a2a25af73d8332b2a35b7ecd89
Binary files /dev/null and b/CombinedMethod/dataset/S_0_0_10_10_10_5_10.npz differ
diff --git a/CombinedMethod/dataset/S_0_0_10_10_10_5_5.npz b/CombinedMethod/dataset/S_0_0_10_10_10_5_5.npz
new file mode 100644
index 0000000000000000000000000000000000000000..d2ef7c87fc7d70296312217b7385fced2cfae8bd
Binary files /dev/null and b/CombinedMethod/dataset/S_0_0_10_10_10_5_5.npz differ
diff --git a/CombinedMethod/dataset/S_0_0_10_10_5_10_10.npz b/CombinedMethod/dataset/S_0_0_10_10_5_10_10.npz
new file mode 100644
index 0000000000000000000000000000000000000000..cd21d0115ef7cc881f8fe7b230e1339392ccc65a
Binary files /dev/null and b/CombinedMethod/dataset/S_0_0_10_10_5_10_10.npz differ
diff --git a/CombinedMethod/src/dataset.py b/CombinedMethod/src/dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..39e62ee4c7c221803cc5b80c45357e1f8536a8ec
--- /dev/null
+++ b/CombinedMethod/src/dataset.py
@@ -0,0 +1,98 @@
+# Copyright 2023 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""
+dataset
+"""
+import os
+
+import numpy as np
+import mindspore.dataset as ds
+from mindspore.common import dtype as mstype
+
+
+def init_dataset(data_params):
+ """initial_dataset"""
+ dataset = DataSet(data_params)
+ print("initializing dataset")
+ train_dataset, test_dataset, means, stds = dataset.create_dataset()
+ return train_dataset, test_dataset, means, stds
+
+
+class HeatConductionData:
+ """HeatConductionData"""
+
+ def __init__(self, data_path, t_in, t_out):
+ input_list = []
+ target_list = []
+ # dataset = []
+ files = os.listdir(data_path)
+ for file_name in files:
+ if file_name.endswith(".npz"):
+ file_path = os.path.join(data_path, file_name)
+ data = np.load(file_path)
+
+ input_data = data['a'][0:2].astype(np.float32)
+ target_data = data['a'][2:3].astype(np.float32)
+ input_transposed = np.transpose(input_data,(1,2,0))
+ target_transposed = np.transpose(target_data,(1,2,0))
+
+ input_list.append(input_transposed)
+ target_list.append(target_transposed)
+ self.inputs = np.array(input_list).astype(np.float32)
+ self.labels = np.array(target_list).astype(np.float32)
+
+ print("input size", self.inputs.shape)
+ print("label size", self.labels.shape)
+
+ def __len__(self):
+ return len(self.inputs)
+
+ def __getitem__(self, idx):
+ return self.inputs[idx], self.labels[idx]
+
+
+class DataSet:
+ """DataSet"""
+
+ def __init__(self, data_params):
+ self.data_path = data_params['root_dir']
+ self.dataset_name = data_params['dataset_name']
+ self.t_in = data_params['T_in']
+ self.t_out = data_params['T_out']
+ self.train_batch_size = data_params['train_batch_size']
+ self.test_batch_size = data_params['test_batch_size']
+
+ self.dataset_generator = HeatConductionData(self.data_path, t_in=self.t_in, t_out=self.t_out)
+ self.mean_inputs = np.array(data_params['means']).astype('float32')
+ self.std_inputs = np.array(data_params['stds']).astype('float32')
+ self.data_size = data_params['data_size']
+ self.train_size = data_params['train_size']
+
+ def create_dataset(self, drop_remainder=True):
+ """create dataset"""
+ dataset = ds.GeneratorDataset(self.dataset_generator, ["inputs", "labels"], shuffle=False)
+ train_ds, test_ds = dataset.split([self.train_size / self.data_size, 1 - self.train_size / self.data_size],
+ randomize=False)
+
+ print("train_batch_size : {}".format(self.train_batch_size))
+
+ data_set_batch_train = train_ds.batch(self.train_batch_size, drop_remainder=drop_remainder)
+ data_set_batch_test = test_ds.batch(self.test_batch_size, drop_remainder=drop_remainder)
+ print("train batch dataset size: {}".format(data_set_batch_train.get_dataset_size()))
+ print("test batch dataset size: {}".format(data_set_batch_test.get_dataset_size()))
+ return data_set_batch_train, data_set_batch_test, self.mean_inputs, self.std_inputs
+
+ def process_fn(self, inputs):
+ return (inputs - self.mean_inputs) / (self.std_inputs + 1e-10)
\ No newline at end of file
diff --git a/CombinedMethod/src/fno2d.py b/CombinedMethod/src/fno2d.py
new file mode 100644
index 0000000000000000000000000000000000000000..e471c700c12686b3e0d9924b6939adfce6735685
--- /dev/null
+++ b/CombinedMethod/src/fno2d.py
@@ -0,0 +1,221 @@
+# Copyright 2023 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""
+fno2d
+"""
+import numpy as np
+import mindspore.common.dtype as mstype
+from mindspore import ops, nn, Tensor, Parameter
+from mindspore.ops import operations as P
+from mindspore.common.initializer import Zero
+
+from mindflow.utils.check_func import check_param_type
+from mindflow.common.math import get_grid_2d
+from mindflow.cell.neural_operators.dft import dft2, idft2
+
+
+class FNO2D(nn.Cell):
+ r"""
+ The 2-dimensional Fourier Neural Operator (FNO2D) contains a lifting layer,
+ multiple Fourier layers and a decoder layer.
+ The details can be found in `Fourier neural operator for parametric
+ partial differential equations `_.
+
+ Args:
+ in_channels (int): The number of channels in the input space.
+ out_channels (int): The number of channels in the output space.
+ resolution (int): The spatial resolution of the input.
+ modes (int): The number of low-frequency components to keep.
+ channels (int): The number of channels after dimension lifting of the input. Default: 20.
+ depths (int): The number of FNO layers. Default: 4.
+ mlp_ratio (int): The number of channels lifting ratio of the decoder layer. Default: 4.
+ compute_dtype (dtype.Number): The computation type of dense layer.
+ Default mstype.float16.
+ Should be mstype.float16 or mstype.float32.
+ mstype.float32 is recommended for the GPU backend, mstype.float16 is recommended for the Ascend backend.
+
+ Inputs:
+ - **x** (Tensor) - Tensor of shape :math:`(batch_size, resolution, resolution, in_channels)`.
+
+ Outputs:
+ Tensor, the output of this FNO network.
+
+ - **output** (Tensor) - Tensor of shape :math:`(batch_size, resolution, resolution, out_channels)`.
+ - grid (Tensor) - Tensor of shape :`(1, resolution, resolution, 2)`
+
+ Raises:
+ TypeError: If `in_channels` is not an int.
+ TypeError: If `out_channels` is not an int.
+ TypeError: If `resolution` is not an int.
+ TypeError: If `modes` is not an int.
+ ValueError: If `modes` is less than 1.
+
+ Supported Platforms:
+ ``Ascend`` ``GPU``
+
+ Examples:
+ >>> import numpy as np
+ >>> from mindspore.common.initializer import initializer, Normal
+ >>> from mindflow.cell.neural_operators import FNO2D
+ >>> B, H, W, C = 32, 64, 64, 1
+ >>> input = initializer(Normal(), [B, H, W, C])
+ >>> net = FNO2D(in_channels=1, out_channels=1, resolution=64, modes=12)
+ >>> output = net(input)
+ >>> print(output.shape)
+ (32, 64, 64, 1)
+
+ """
+
+ def __init__(self,
+ in_channels,
+ out_channels,
+ resolution,
+ modes,
+ channels=20,
+ depths=4,
+ mlp_ratio=4,
+ compute_dtype=mstype.float32):
+ super().__init__()
+ check_param_type(in_channels, "in_channels",
+ data_type=int, exclude_type=bool)
+ check_param_type(out_channels, "out_channels",
+ data_type=int, exclude_type=bool)
+ check_param_type(resolution, "resolution",
+ data_type=int, exclude_type=bool)
+ check_param_type(modes, "modes", data_type=int, exclude_type=bool)
+ if modes < 1:
+ raise ValueError("modes must be at least 1, but got mode: {}".format(modes))
+ self.compute_dtype = compute_dtype
+
+ self.modes1 = modes
+ self.channels = channels
+ self.fc_channel = mlp_ratio * channels
+ self.fc0 = nn.Dense(in_channels + 2, self.channels, has_bias=True,
+ weight_init='Uniform', bias_init='Uniform').to_float(self.compute_dtype)
+ self.layers = depths
+
+ self.fno_seq = nn.SequentialCell()
+ for _ in range(self.layers):
+ self.fno_seq.append(FNOBlock(self.channels, self.channels, modes1=self.modes1,
+ resolution=resolution, compute_dtype=self.compute_dtype))
+
+ self.fc1 = nn.Dense(self.channels, 128, has_bias=True, weight_init='Uniform',
+ bias_init='Uniform').to_float(self.compute_dtype)
+ self.fc2 = nn.Dense(128, out_channels, has_bias=True, weight_init='Uniform',
+ bias_init='Uniform').to_float(self.compute_dtype)
+
+ self.grid = Tensor(get_grid_2d(resolution), self.compute_dtype)
+ self.concat = ops.Concat(axis=-1)
+ self.act = ops.ReLU()
+
+ def construct(self, x: Tensor):
+ """forward"""
+ batch_size = x.shape[0]
+ grid = self.grid.repeat(batch_size, axis=0)
+ x = P.Concat(-1)((x, grid))
+ x = self.fc0(x)
+ x = P.Transpose()(x, (0, 3, 1, 2))
+ x = self.fno_seq(x)
+ x = P.Transpose()(x, (0, 2, 3, 1))
+ x = self.fc1(x)
+ x = self.act(x)
+ output = self.fc2(x)
+ return output
+
+
+class FNOBlock(nn.Cell):
+ """FNOBlock"""
+ def __init__(self, in_channels, out_channels, modes1, resolution=128, compute_dtype=mstype.float32):
+ super().__init__()
+ self.compute_dtype = compute_dtype
+ self.conv = SpectralConv2dDft(in_channels, out_channels, modes1, modes1, resolution,
+ resolution, compute_dtype=mstype.float32)
+ self.w = nn.Conv2d(in_channels, out_channels, 1, has_bias=True,
+ weight_init='HeUniform').to_float(self.compute_dtype)
+ self.act = ops.ReLU()
+
+ def construct(self, x):
+ return self.act(self.conv(x) + self.w(x))
+
+
+class SpectralConv2dDft(nn.Cell):
+ """SpectralConv2dDft"""
+ def __init__(self, in_channels, out_channels, modes1, modes2, column_resolution, raw_resolution,
+ compute_dtype=mstype.float32):
+ super().__init__()
+ self.in_channels = in_channels
+ self.out_channels = out_channels
+ self.modes1 = modes1
+ self.modes2 = modes2
+ self.column_resolution = column_resolution
+ self.raw_resolution = raw_resolution
+ self.compute_dtype = compute_dtype
+ self.scale = (1. / (in_channels * out_channels))
+
+ w_re1 = Tensor(self.scale * np.random.rand(in_channels, out_channels, modes1, modes2),
+ dtype=mstype.float32)
+ w_im1 = Tensor(self.scale * np.random.rand(in_channels, out_channels, modes1, modes2),
+ dtype=mstype.float32)
+ w_re2 = Tensor(self.scale * np.random.rand(in_channels, out_channels, modes1, modes2),
+ dtype=mstype.float32)
+ w_im2 = Tensor(self.scale * np.random.rand(in_channels, out_channels, modes1, modes2),
+ dtype=mstype.float32)
+
+ self.w_re1 = Parameter(w_re1, requires_grad=True)
+ self.w_im1 = Parameter(w_im1, requires_grad=True)
+ self.w_re2 = Parameter(w_re2, requires_grad=True)
+ self.w_im2 = Parameter(w_im2, requires_grad=True)
+ self.dft2_cell = dft2(shape=(column_resolution, raw_resolution),
+ modes=(modes1, modes2), compute_dtype=self.compute_dtype)
+ self.idft2_cell = idft2(shape=(column_resolution, raw_resolution),
+ modes=(modes1, modes2), compute_dtype=self.compute_dtype)
+ self.mat = Tensor(shape=(1, out_channels, column_resolution - 2 * modes1, modes2),
+ dtype=self.compute_dtype, init=Zero())
+ self.concat = ops.Concat(-2)
+
+ @staticmethod
+ def mul2d(inputs, weights):
+ weight = weights.expand_dims(0)
+ data = inputs.expand_dims(2)
+ out = weight * data
+ return out.sum(1)
+
+ def construct(self, x: Tensor):
+ """forward"""
+ x_re = x
+ x_im = ops.zeros_like(x_re)
+ x_ft_re, x_ft_im = self.dft2_cell((x_re, x_im))
+
+ out_ft_re1 = \
+ self.mul2d(x_ft_re[:, :, :self.modes1, :self.modes2], self.w_re1) \
+ - self.mul2d(x_ft_im[:, :, :self.modes1, :self.modes2], self.w_im1)
+ out_ft_im1 = \
+ self.mul2d(x_ft_re[:, :, :self.modes1, :self.modes2], self.w_im1) \
+ + self.mul2d(x_ft_im[:, :, :self.modes1, :self.modes2], self.w_re1)
+
+ out_ft_re2 = \
+ self.mul2d(x_ft_re[:, :, -self.modes1:, :self.modes2], self.w_re2) \
+ - self.mul2d(x_ft_im[:, :, -self.modes1:, :self.modes2], self.w_im2)
+ out_ft_im2 = \
+ self.mul2d(x_ft_re[:, :, -self.modes1:, :self.modes2], self.w_im2) \
+ + self.mul2d(x_ft_im[:, :, -self.modes1:, :self.modes2], self.w_re2)
+
+ batch_size = x.shape[0]
+ mat = ops.cast(self.mat.repeat(batch_size, 0), self.compute_dtype)
+ out_re = self.concat((out_ft_re1, mat, out_ft_re2))
+ out_im = self.concat((out_ft_im1, mat, out_ft_im2))
+
+ x, _ = self.idft2_cell((out_re, out_im))
+ return x
\ No newline at end of file
diff --git a/CombinedMethod/src/unet.py b/CombinedMethod/src/unet.py
new file mode 100644
index 0000000000000000000000000000000000000000..3237356fdabc4d4b2afb01674ed904b1d60dc717
--- /dev/null
+++ b/CombinedMethod/src/unet.py
@@ -0,0 +1,153 @@
+# Copyright 2023 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""
+unet2d
+"""
+import mindspore.nn as nn
+import mindspore.ops as ops
+from mindspore.ops import operations as P
+
+
+class Unet2D(nn.Cell):
+ r"""
+ UNet2D model
+
+ Args:
+ in_channels (int): The input feature size of input.
+ out_channels (int): The output feature size of output.
+ resolution (int): The spatial resolution of the input.
+ kernel_size (int): Specifies the height and width of the 2D convolution kernel. Default: 2.
+ stride (Union[int, tuple[int]]): The distance of kernel moving,
+ an int number that represents the height and width of movement are both stride,
+ or a tuple of two int numbers that represent height and width of movement respectively. Default: 2.
+
+ Inputs:
+ - **input** (Tensor) - Tensor of shape :math:`(batch\_size, resolution, resolution, channels)`.
+
+ Outputs:
+ - **output** (Tensor) - Tensor of shape :math:`((batch\_size, resolution, resolution, channels)`.
+
+ Supported Platforms:
+ ``Ascend`` ``GPU``
+
+ Examples:
+ >>> import mindspore as ms
+ >>> from mindspore import Tensor
+ >>> import mindspore.common.dtype as mstype
+ >>> ms.set_context(mode=ms.GRAPH_MODE, save_graphs=False, device_target="GPU")
+ >>> x=Tensor(np.ones([2, 128, 128, 3]), mstype.float32)
+ >>> unet = Unet2D(in_channels=3, out_channels=3)
+ >>> output=unet(x)
+ >>> print(res_x.shape)
+ (2, 128, 128, 3)
+ """
+
+ def __init__(self, in_channels, out_channels, base_channels, kernel_size=2, stride=2):
+ super().__init__()
+ self.in_channels = in_channels
+ self.out_channels = out_channels
+ self.kernel_size = kernel_size
+ self.stride = stride
+ self.base_channels = base_channels
+
+ self.inc = DoubleConv(self.in_channels, self.base_channels, mid_channels=None)
+ self.down1 = Down(self.base_channels, self.base_channels * 2, self.kernel_size, self.stride)
+ self.down2 = Down(self.base_channels * 2, self.base_channels * 4, self.kernel_size, self.stride)
+ self.down3 = Down(self.base_channels * 4, self.base_channels * 8, self.kernel_size, self.stride)
+ self.down4 = Down(self.base_channels * 8, self.base_channels * 16, self.kernel_size, self.stride)
+ self.up1 = Up(self.base_channels * 16, self.base_channels * 8, self.kernel_size, self.stride)
+ self.up2 = Up(self.base_channels * 8, self.base_channels * 4, self.kernel_size, self.stride)
+ self.up3 = Up(self.base_channels * 4, self.base_channels * 2, self.kernel_size, self.stride)
+ self.up4 = Up(self.base_channels * 2, self.base_channels, self.kernel_size, self.stride)
+ self.outc = nn.Conv2d(self.base_channels + self.in_channels, self.out_channels, kernel_size=3, stride=1)
+ self.transpose = P.Transpose()
+ self.cat = P.Concat(axis=1)
+
+ def construct(self, x):
+ """forward"""
+ x0 = self.transpose(x, (0, 3, 1, 2))
+ x1 = self.inc(x0)
+ x2 = self.down1(x1)
+ x3 = self.down2(x2)
+ x4 = self.down3(x3)
+ x5 = self.down4(x4)
+ x = self.up1(x5, x4)
+ x = self.up2(x, x3)
+ x = self.up3(x, x2)
+ x = self.up4(x, x1)
+ x = self.cat((x, x0))
+ x = self.outc(x)
+ out = self.transpose(x, (0, 2, 3, 1))
+
+ return out
+
+
+class DoubleConv(nn.Cell):
+ """double conv"""
+
+ def __init__(self, in_channels, out_channels, mid_channels=None):
+ super().__init__()
+ if not mid_channels:
+ mid_channels = out_channels
+ self.double_conv = nn.SequentialCell(
+ nn.Conv2d(in_channels, mid_channels, kernel_size=3),
+ nn.BatchNorm2d(mid_channels),
+ nn.ReLU(),
+ nn.Conv2d(mid_channels, out_channels, kernel_size=3),
+ nn.BatchNorm2d(out_channels),
+ nn.ReLU()
+ )
+
+ def construct(self, x):
+ """forward"""
+ return self.double_conv(x)
+
+
+class Down(nn.Cell):
+ """down"""
+
+ def __init__(self, in_channels, out_channels, kernel_size=2, stride=2):
+ super().__init__()
+ self.conv = DoubleConv(in_channels, out_channels)
+ self.maxpool = nn.MaxPool2d(kernel_size=kernel_size, stride=stride)
+
+ def construct(self, x):
+ """forward"""
+ x = self.maxpool(x)
+ return self.conv(x)
+
+
+class Up(nn.Cell):
+ """up"""
+
+ def __init__(self, in_channels, out_channels, kernel_size=2, stride=2):
+ super().__init__()
+ self.up = nn.Conv2dTranspose(in_channels, in_channels // 2, kernel_size=kernel_size, stride=stride)
+ self.conv = DoubleConv(in_channels, out_channels)
+ self.cat = ops.Concat(axis=1)
+
+ def construct(self, x1, x2):
+ """forward"""
+ x1 = self.up(x1)
+
+ _, _, h1, w1 = ops.shape(x1)
+ _, _, h2, w2 = ops.shape(x2)
+
+ diff_y = w2 - w1
+ diff_x = h2 - h1
+
+ x1 = ops.Pad(((0, 0), (0, 0), (diff_x // 2, diff_x - diff_x // 2), (diff_y // 2, diff_y - diff_y // 2)))(x1)
+ x = self.cat((x2, x1))
+ return self.conv(x)
\ No newline at end of file
diff --git a/CombinedMethod/src/utils.py b/CombinedMethod/src/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..2e8f9fd288533d80c925aeb4f6dc351404c309a9
--- /dev/null
+++ b/CombinedMethod/src/utils.py
@@ -0,0 +1,219 @@
+# Copyright 2023 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""
+utils
+"""
+import os
+
+import numpy as np
+from scipy.signal import convolve2d
+import matplotlib.pyplot as plt
+from mindspore import nn, ops, jit_class, Tensor, Parameter
+from mindspore.common.initializer import initializer
+from mindspore import dtype as mstype
+
+from .unet import Unet2D
+from .fno2d import FNO2D
+
+
+# from .operators import Operate
+
+
def init_model(backbone, data_params, model_params, compute_dtype=mstype.float32):
    """Build the backbone network (FNO2D or Unet2D) from the config dicts."""
    if backbone == "fno2d":
        fno_cfg = model_params["fno2d"]
        return FNO2D(in_channels=model_params["in_channels"] * data_params['T_in'],
                     out_channels=model_params["out_channels"],
                     resolution=model_params["resolution"],
                     modes=fno_cfg["modes"],
                     channels=fno_cfg["channels"],
                     depths=fno_cfg["depths"],
                     mlp_ratio=fno_cfg["mlp_ratio"],
                     compute_dtype=compute_dtype)
    # NOTE: unlike FNO2D, the U-Net input width is NOT multiplied by T_in
    # (single-frame input).
    unet_cfg = model_params["unet2d"]
    return Unet2D(in_channels=model_params["in_channels"],
                  out_channels=model_params["out_channels"],
                  base_channels=unet_cfg["base_channels"],
                  kernel_size=unet_cfg["kernel_size"],
                  stride=unet_cfg["stride"])
+
+
def check_file_path(file_path):
    """Ensure *file_path* exists as a directory, creating it (and parents) if needed.

    Uses ``exist_ok=True`` so the call is atomic and idempotent; the original
    check-then-create pair could race with a concurrent creator (TOCTOU) and
    crash on ``FileExistsError``.
    """
    os.makedirs(file_path, exist_ok=True)
+
+
def count_params(params):
    """Return the total number of scalar elements across all *params*.

    Each entry only needs a ``.shape`` tuple (Parameter, Tensor, ndarray).
    ``np.prod`` of an empty shape is 1, so 0-d scalars count as one element,
    matching the original hand-rolled product loop.
    """
    return int(sum(np.prod(p.shape) for p in params))
+
+
def plot_image(tensor, sampleNo, channel=0):
    """Display one channel of one sample of a batch tensor as an image.

    Generalized with a ``channel`` parameter (default 0 keeps the original
    behavior) so sibling helpers for other channels are unnecessary.

    Args:
        tensor: batch tensor exposing ``asnumpy()``; indexed as
            (batch, height, width, channels) — assumed NHWC, TODO confirm
            against callers.
        sampleNo: index of the sample in the batch to display.
        channel: channel index to display (default 0).
    """
    image = tensor.asnumpy()[sampleNo, :, :, channel]
    plt.imshow(image, cmap='rainbow')
    plt.axis('off')
    plt.show()
+
+
def plot_image_1(tensor, sampleNo):
    """Display channel 1 of the sampleNo-th sample as a rainbow-colormap image."""
    frame = tensor.asnumpy()[sampleNo, :, :, 1]
    plt.imshow(frame, cmap='rainbow')
    plt.axis('off')
    plt.show()
+
+
@jit_class
class Trainer():
    r"""
    Trainer combining a data-driven loss with a physics-driven (Laplacian) loss.

    Args:
        model (Model): The Unet2D/FNO2D model.
        data_params (dict): The data parameters loaded from yaml file.
        loss_fn (Tensor): The loss function.
        means (list): The mean value of every input channel.
        stds (list): The standard deviation value of every input channel.
        window_size (int): Size of the sliding window used to average recent
            losses when rebalancing the data/physics loss weights.

    Inputs:
        - inputs (Tensor) - Tensor of shape :math:`(batch\_size*T_in, resolution, resolution, channels)`.
        - labels (Tensor) - Tensor of shape :math:`(batch\_size, resolution, resolution, channels)`.

    Outputs:
        - loss (float) - The average loss calculated by average of test step losses.
        - loss_full (float) - The average loss directly calculated for the current batch.
        - pred (Tensor) - Tensor of shape :math:`(batch\_size, resolution, resolution, channels)`.
        - step_losses (list) - The list of step losses with length of T_out
    """

    def __init__(self, model, data_params, loss_fn, means, stds, window_size=3):
        self.model = model
        self.test_steps = data_params["T_out"]
        self.loss_fn = loss_fn
        self.mean = Tensor(means, dtype=mstype.float32)
        self.std = Tensor(stds, dtype=mstype.float32)
        self.window_size = window_size  # size of the sliding loss window
        self.loss1_list = []  # recent data-driven losses
        self.loss2_list = []  # recent physics-driven losses
        self.loss1_weight = 100  # initial weight of the data-driven loss
        self.MAX_WEIGHT = 200  # upper bound for loss1_weight

    def _build_features(self, inputs):
        """Cast the raw inputs to float32 features."""
        return inputs.astype(mstype.float32)

    def get_loss(self, inputs, labels):
        """get loss

        Computes the weighted sum of the data-driven loss (prediction vs.
        labels) and the physics-driven loss (discrete Laplacian of the
        prediction driven towards zero on the interior).
        NOTE(review): the interior slicing below hard-codes a 128x128
        spatial resolution — confirm against the data pipeline.
        """

        embeds = self._build_features(inputs)
        y = labels[:, :, :, :]

        batch_no = embeds.shape[0]
        loss = 0
        pred = 0
        step_losses = []

        embeds = ops.cast(embeds, mstype.float32)
        # Channel 0 is used as boundary-condition values, channel 1 as a mask
        # (per the variable names; verify against the dataset builder).
        BCs = embeds[:, :, :, 0:1]
        mask = embeds[:, :, :, 1:2]
        # Interior mask with the outermost one-pixel ring stripped (126x126).
        mask_no_BCout = embeds[:, 1:127, 1:127, 1:2]

        # mask_with_BCout: 128*128
        # (previous numpy-based reconstruction, superseded by ops.pad below)
        # lineCat = np.zeros((batch_no, 126, 126, 1)).astype(np.float32) + 1
        # mask_with_BCout = np.concatenate((lineCat[:, :, :1, :],
        # mask_no_BCout[:, :, :, :],
        # lineCat[:, :, :-1, :]), axis=2)
        # columnCat = np.zeros((batch_no, 128, 128, 1)).astype(np.float32) + 1
        # mask_with_BCout = np.concatenate((columnCat[:, :1, :, :],
        # mask_with_BCout[:, :, :, :],
        # columnCat[:, :-1, :, :]), axis=1)

        # Rebuild the full-size mask by padding the interior mask with ones
        # (NCHW layout for ops.pad, then back to NHWC).
        mask_no_BCout_transposed = ops.transpose(mask_no_BCout, (0, 3, 1, 2))
        mask_with_BCout_transposed = ops.pad(mask_no_BCout_transposed, [1, 1, 1, 1], mode='constant', value=1.0)
        mask_with_BCout = ops.transpose(mask_with_BCout_transposed, (0, 2, 3, 1))

        im_org = self.model(embeds)
        im_org = ops.cast(im_org, mstype.float32)
        # Overwrite masked locations with the prescribed boundary values.
        pred = im_org * (1 - mask_with_BCout) + BCs
        # pred = im_org * (1 - mask) + BCs

        # data_driven loss
        l_data = self.loss_fn(pred, y)

        # physics_driven loss
        pred_tensor = ops.transpose(pred, (0, 3, 1, 2))

        # Second-difference stencils along each axis; their sum applied via
        # 3x3 convolution is the 5-point discrete Laplacian.
        desired_weight_x = np.array([[[[0, 1.0, 0], [0, -2.0, 0], [0, 1.0, 0]]]])
        weight_x = Tensor(desired_weight_x, dtype=mstype.float32)
        desired_weight_y = np.array([[[[0, 0, 0], [1.0, -2.0, 1.0], [0, 0, 0]]]])
        weight_y = Tensor(desired_weight_y, dtype=mstype.float32)
        conv2d = ops.Conv2D(out_channel=1, kernel_size=3)
        dtd2 = conv2d(pred_tensor, weight_x) + conv2d(pred_tensor, weight_y)

        dtd2 = ops.transpose(dtd2, (0, 2, 3, 1))
        # Only penalize the Laplacian where the interior mask is zero.
        dtd2 = dtd2 * (1 - mask_no_BCout)

        np_array = np.zeros((batch_no, 126, 126, 1))
        tensor_zeros = Tensor(np_array, dtype=mstype.float32)
        l_phy = self.loss_fn(dtd2, tensor_zeros)

        # Loss
        # l_total = l_data
        # l_total = l_phy
        l_total = self.loss1_weight * l_data + l_phy

        step_losses.append(l_total)
        loss += l_total

        return loss, l_data, l_phy, inputs, pred, labels, step_losses
        #return loss, inputs, pred, labels, step_losses

    def update_loss_lists(self, loss1, loss2):
        """Append the latest data/physics losses, keeping at most window_size entries."""
        self.loss1_list.append(loss1.asnumpy())
        self.loss2_list.append(loss2.asnumpy())
        if len(self.loss1_list) > self.window_size:
            # keep the loss lists at the sliding-window size
            self.loss1_list.pop(0)
            self.loss2_list.pop(0)

    def adjust_weights(self, epoch):
        """Rebalance loss1_weight so both loss terms have comparable magnitude."""
        if len(self.loss1_list) >= self.window_size:
            # average losses over the sliding window
            sliding_avg_loss1 = np.mean(self.loss1_list[-self.window_size:])
            sliding_avg_loss2 = np.mean(self.loss2_list[-self.window_size:])
        else:
            sliding_avg_loss1 = np.mean(self.loss1_list)
            sliding_avg_loss2 = np.mean(self.loss2_list)

        # adjust the weight
        if sliding_avg_loss2 > 0:  # guard against division by zero
            current_ratio = sliding_avg_loss2 / (sliding_avg_loss1 * self.loss1_weight)
            self.loss1_weight *= current_ratio
            self.loss1_weight = min(self.loss1_weight, self.MAX_WEIGHT)

    def _denormalize(self, x):
        """Map normalized values back to the original scale."""
        return x * self.std + self.mean
diff --git a/CombinedMethod/src/visual.py b/CombinedMethod/src/visual.py
new file mode 100644
index 0000000000000000000000000000000000000000..985831dd341d0fe82b722fc7cf4a0d91f88ee293
--- /dev/null
+++ b/CombinedMethod/src/visual.py
@@ -0,0 +1,86 @@
+# Copyright 2023 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""
+visual
+"""
+import os
+
+import numpy as np
+import matplotlib.pyplot as plt
+import matplotlib.gridspec as gridspec
+from mpl_toolkits.axes_grid1 import make_axes_locatable, axes_size
+
+
def plt_log(predicts, labels, img_dir, epoch=0):
    """Plot label / prediction / error panels per output channel and save PNGs.

    For channel ``i`` (prefixes "U", "V", "P" in order) the figure holds three
    rows — label, prediction, absolute error — sampled at up to 10 time steps,
    and is saved as ``<prefix>_epoch-<epoch>_result.png`` under *img_dir*.

    Args:
        predicts: array-like of shape (batch, T, H, W, C); only sample 0 is plotted.
        labels: array-like with the same shape as *predicts*.
        img_dir: directory the PNG files are written into.
        epoch: epoch number embedded in the file names.
    """
    plt.rcParams['figure.figsize'] = (12, 4.8)
    prefixes = ["U", "V", "P"]
    # BUG FIX: the original nested a second `for prefix in prefixes` loop
    # inside the channel loop, re-plotting identical data for all three
    # prefixes and letting the last channel overwrite every file. Each
    # prefix now maps to its own channel.
    for i, prefix in enumerate(prefixes):
        label = labels[0, ..., i]
        predict = predicts[0, ..., i]

        error = np.abs(label - predict)

        # Label and prediction share one color scale; the error gets its own.
        vmin_u = label.min()
        vmax_u = label.max()
        vmin = [vmin_u, vmin_u, error.min()]
        vmax = [vmax_u, vmax_u, error.max()]

        # Sample at most 10 evenly spaced time steps.
        t = len(label)
        step = int(t // np.minimum(10, t))
        times = [k * step for k in range(np.minimum(10, t))]

        sub_titles = ["Label", "Predict", "Error"]
        items = ["$T=%d$" % tt for tt in times]

        label_2d = [label[tt, ...] for tt in times]
        predict_2d = [predict[tt, ...] for tt in times]
        error_2d = [error[tt, ...] for tt in times]

        fig = plt.figure()
        gs = gridspec.GridSpec(3, len(times))
        gs_idx = 0

        for j, data_2d in enumerate([label_2d, predict_2d, error_2d]):
            for k, data in enumerate(data_2d):
                ax = fig.add_subplot(gs[gs_idx])
                gs_idx += 1

                img = ax.imshow(data.T, vmin=vmin[j], vmax=vmax[j], cmap=plt.get_cmap("turbo"), origin='lower')

                ax.set_title(sub_titles[j] + " " + items[k], fontsize=10)
                plt.axis('off')

                # Attach a thin colorbar whose height tracks the axes.
                aspect = 20
                pad_fraction = 0.5
                divider = make_axes_locatable(ax)
                width = axes_size.AxesY(ax, aspect=1 / aspect)
                pad = axes_size.Fraction(pad_fraction, width)
                cax = divider.append_axes("right", size=width, pad=pad)
                cb = plt.colorbar(img, cax=cax)
                cb.ax.tick_params(labelsize=6)

        gs.tight_layout(fig, pad=0.2, w_pad=0.2, h_pad=0.2)

        file_name = os.path.join(img_dir, prefix + "_epoch-%d_result.png" % epoch)
        fig.savefig(file_name)

    plt.close()
\ No newline at end of file