diff --git a/MindEarth/applications/earthquake/G-TEAM/G-TEAM.ipynb b/MindEarth/applications/earthquake/G-TEAM/G-TEAM.ipynb
new file mode 100644
index 0000000000000000000000000000000000000000..211fcfb2a527dede9e1ddee928cf4005200397fe
--- /dev/null
+++ b/MindEarth/applications/earthquake/G-TEAM/G-TEAM.ipynb
@@ -0,0 +1,339 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "id": "daaa37d8-af16-47a9-b512-b8aaf15f4fe3",
+ "metadata": {},
+ "source": [
+ "# G-TEAM地震预警模型\n",
+ "\n",
+ "## 概述\n",
+ "\n",
+ "地震预警系统旨在在破坏性震动到达前尽早发出警报,以减少人员伤亡和经济损失。G-TEAM 模型是一种数据驱动的全国地震预警系统,结合了图神经网络(GNN)和 Transformer 架构,能够在地震发生后 3 秒内迅速提供震中位置、震级及地震强度分布。该模型通过直接处理原始地震波形数据,避免了手动特征选择的限制,并充分利用多台站数据,提高了预测的准确性和实时性。\n",
+ "\n",
+ "本模型是一款高效的地震预警系统,结合了图神经网络(Graph Neural Network, GNN)与 Transformer 架构,以任意数量的地震台站记录的地震波形数据作为输入。该模型能够实时接收地震信号,并对震源位置、震级以及地震烈度分布范围进行快速且精准的估计,其中烈度分布范围以地面峰值加速度(Peak Ground Acceleration, PGA)表征。通过深度学习方法,本模型可以充分利用地震台网的空间关联性与时序特征,提高预警精度和响应速度,为地震应急响应和减灾决策提供可靠支持。\n",
+ "\n",
+ "\n",
+ "\n",
+ "该模型采用多源地震台站数据进行PGA预测,具体架构如下:首先,系统接收多个地震台站的位置信息及其记录的地震波形数据,同时获取待估计PGA的目标位置坐标。对于每个地震台站的波形数据,首先进行标准化处理,随后通过卷积神经网络(CNN)进行特征提取。提取的特征经全连接层进行特征融合,并与对应台站的位置信息共同构成特征向量。\n",
+ "目标PGA位置坐标经过位置编码模块处理后,形成特征向量。所有特征向量按序列形式输入到Transformer编码器中,编码器通过自注意力机制捕捉全局依赖关系。编码器输出依次通过三个独立的全连接层,分别完成地震事件震级、震中位置以及PGA的回归预测任务。\n",
+ "\n",
+ "本模型的训练数据来源于[谛听数据集2.0 -中国地震台网多功能大型人工智能训练数据集](http://www.esdc.ac.cn/article/137),该数据集汇集了中国大陆及其邻近地区(15°-50°N,65°-140°E)1177 个中国地震台网固定台站的波形记录,覆盖时间范围为 2020 年 3 月至 2023 年 2 月。数据集包含研究区域内所有震级大于 0 的地方震事件,共计 264,298 个。我们在训练过程中仅选取了初至 P 波和 S 波震相,并且只保留至少被三个台站记录到的地震事件,以确保数据的可靠性和稳定性。\n",
+ "\n",
+ "目前本模型已开源推理部分,可使用提供的ckpt进行推理。\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 19,
+ "id": "2a6dec0b-9307-4317-891a-c168fa400648",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import numpy as np\n",
+ "\n",
+ "import mindspore as ms\n",
+ "from mindspore import context"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "ced04ef1-9f41-429c-817d-9ed87ad67209",
+ "metadata": {},
+ "source": [
+ "下述src可在[GTEAM/src](./src)下载"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 13,
+ "id": "cb8b38ab-430d-48c7-9939-3dda1718802e",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from mindearth import load_yaml_config, make_dir\n",
+ "\n",
+ "from src.utils import (\n",
+ " predict_at_time,\n",
+ " calc_mag_stats,\n",
+ " calc_loc_stats,\n",
+ " calc_pga_stats,\n",
+ " init_model,\n",
+ " get_logger\n",
+ ")\n",
+ "from src.forcast import GTeamInference\n",
+ "from src.data import load_data\n",
+ "from src.visual import generate_true_pred_plot"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "40e22dca-f372-459b-97e9-2cf09a5ac412",
+ "metadata": {},
+ "source": [
+ "可以在[配置文件](./config/GTEAM.yaml)中配置模型、数据和优化器等参数。"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 14,
+ "id": "f4e613c6-95a6-4af4-9636-9f62df623b7c",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "config = load_yaml_config(\"./config/GTEAM.yaml\")\n",
+ "context.set_context(mode=ms.PYNATIVE_MODE)\n",
+ "ms.set_device(device_target=\"Ascend\", device_id=0)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 15,
+ "id": "6160aeba-7b83-4bbc-941d-b0a7ffe1e4cf",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "2025-04-05 08:57:36,391 - utils.py[line:179] - INFO: {'hidden_dim': 1000, 'hidden_dropout': 0.0, 'n_heads': 10, 'n_pga_targets': 15, 'output_location_dims': [150, 100, 50, 30, 3], 'output_mlp_dims': [150, 100, 50, 30, 1], 'transformer_layers': 6, 'waveform_model_dims': [500, 500, 500], 'wavelength': [[0.01, 15], [0.01, 15], [0.01, 10]], 'times': [5], 'run_with_less_data': False, 'pga': True, 'mode': 'test', 'no_event_token': False}\n",
+ "2025-04-05 08:57:36,391 - utils.py[line:179] - INFO: {'hidden_dim': 1000, 'hidden_dropout': 0.0, 'n_heads': 10, 'n_pga_targets': 15, 'output_location_dims': [150, 100, 50, 30, 3], 'output_mlp_dims': [150, 100, 50, 30, 1], 'transformer_layers': 6, 'waveform_model_dims': [500, 500, 500], 'wavelength': [[0.01, 15], [0.01, 15], [0.01, 10]], 'times': [5], 'run_with_less_data': False, 'pga': True, 'mode': 'test', 'no_event_token': False}\n",
+ "2025-04-05 08:57:36,392 - utils.py[line:179] - INFO: {'root_dir': './dataset', 'batch_size': 64, 'max_stations': 5, 'disable_station_foreshadowing': True, 'key': 'Mag', 'magnitude_resampling': 1, 'min_mag': 'None', 'min_upsample_magnitude': 4, 'aug_large': True, 'pga_from_inactive': True, 'pga_key': 'pga', 'pga_selection_skew': 1000, 'pos_offset': [30, 102], 'scale_metadata': False, 'selection_skew': 1000, 'shuffle_train_dev': True, 'transform_target_only': False, 'trigger_based': True, 'waveform_shape': [3000, 3], 'overwrite_sampling_rate': 'None', 'noise_seconds': 5}\n",
+ "2025-04-05 08:57:36,392 - utils.py[line:179] - INFO: {'root_dir': './dataset', 'batch_size': 64, 'max_stations': 5, 'disable_station_foreshadowing': True, 'key': 'Mag', 'magnitude_resampling': 1, 'min_mag': 'None', 'min_upsample_magnitude': 4, 'aug_large': True, 'pga_from_inactive': True, 'pga_key': 'pga', 'pga_selection_skew': 1000, 'pos_offset': [30, 102], 'scale_metadata': False, 'selection_skew': 1000, 'shuffle_train_dev': True, 'transform_target_only': False, 'trigger_based': True, 'waveform_shape': [3000, 3], 'overwrite_sampling_rate': 'None', 'noise_seconds': 5}\n",
+ "2025-04-05 08:57:36,394 - utils.py[line:179] - INFO: {'summary_dir': './summary', 'ckpt_path': '/home/lry/202542测试/PreDiff/G-TEAM/dataset/ckpt/g_team.ckpt'}\n",
+ "2025-04-05 08:57:36,394 - utils.py[line:179] - INFO: {'summary_dir': './summary', 'ckpt_path': '/home/lry/202542测试/PreDiff/G-TEAM/dataset/ckpt/g_team.ckpt'}\n"
+ ]
+ }
+ ],
+ "source": [
+ "save_dir = config[\"summary\"].get(\"summary_dir\", \"./summary\")\n",
+ "make_dir(save_dir)\n",
+ "logger_obj = get_logger(config)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "e39876ce-308a-4301-9182-0086423fe062",
+ "metadata": {},
+ "source": [
+ "## 初始化模型"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 16,
+ "id": "a408f4ce-92fd-401b-941e-d60e3aaa9044",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "model = init_model(config)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "11f3dcd6-ef19-461d-9db5-ab2ef2ef2ec9",
+ "metadata": {},
+ "source": [
+ "## 数据集准备\n",
+ "\n",
+ "根据地震后发生时间选择不同台站检测的数据"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "4fc7dc04-8b29-4f88-8a51-5dbd628e02c7",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "class GTeamInference:\n",
+ " \"\"\"\n",
+ " Initialize the GTeamInference class.\n",
+ " \"\"\"\n",
+ "\n",
+ " def __init__(self, model_ins, cfg, output_dir, logger):\n",
+ " \"\"\"\n",
+ " Args:\n",
+ " model_ins: The model instance used for inference.\n",
+ " cfg: Configuration dictionary containing model and data parameters.\n",
+ " output_dir: Directory to save the output results.\n",
+ " Attributes:\n",
+ " model: The model instance for inference.\n",
+ " cfg: Configuration dictionary.\n",
+ " output_dir: Directory to save outputs.\n",
+ " pga: Flag indicating if PGA (Peak Ground Acceleration) is enabled.\n",
+ " generator_params: Parameters for data generation.\n",
+ " model_params: Parameters specific to the model.\n",
+ " mag_key: Key for magnitude-related data.\n",
+ " pos_offset: Position offset for location predictions.\n",
+ " mag_stats: List to store magnitude prediction statistics.\n",
+ " loc_stats: List to store location prediction statistics.\n",
+ " pga_stats: List to store PGA prediction statistics.\n",
+ " \"\"\"\n",
+ " self.model = model_ins\n",
+ " self.cfg = cfg\n",
+ " self.output_dir = output_dir\n",
+ " self.logger = logger\n",
+ " self.pga = cfg[\"model\"].get(\"pga\", \"true\")\n",
+ " self.generator_params = cfg[\"data\"]\n",
+ " self.model_params = cfg[\"model\"]\n",
+ " self.output_dir = output_dir\n",
+ " self.mag_key = self.generator_params[\"key\"]\n",
+ " self.pos_offset = self.generator_params[\"pos_offset\"]\n",
+ " self.mag_stats = []\n",
+ " self.loc_stats = []\n",
+ " self.pga_stats = []\n",
+ "\n",
+ " def _parse_predictions(self, pred):\n",
+ " \"\"\"\n",
+ " Parse the raw predictions into magnitude, location, and PGA components.\n",
+ " \"\"\"\n",
+ " mag_pred = pred[0]\n",
+ " loc_pred = pred[1]\n",
+ " pga_pred = pred[2] if self.pga else []\n",
+ " return mag_pred, loc_pred, pga_pred\n",
+ "\n",
+ " def _process_predictions(\n",
+ " self, mag_pred, loc_pred, pga_pred, time, evt_metadata, pga_true\n",
+ " ):\n",
+ " \"\"\"\n",
+ " Process the parsed predictions to compute statistics and generate plots.\n",
+ " \"\"\"\n",
+ " mag_pred_np = [t[0].asnumpy() for t in mag_pred]\n",
+ " mag_pred_reshaped = np.concatenate(mag_pred_np, axis=0)\n",
+ "\n",
+ " loc_pred_np = [t[0].asnumpy() for t in loc_pred]\n",
+ " loc_pred_reshaped = np.array(loc_pred_np)\n",
+ "\n",
+ " pga_pred_np = [t.asnumpy() for t in pga_pred]\n",
+ " pga_pred_reshaped = np.concatenate(pga_pred_np, axis=0)\n",
+ " pga_true_reshaped = np.log(\n",
+ " np.abs(np.concatenate(pga_true, axis=0).reshape(-1, 1))\n",
+ " )\n",
+ "\n",
+ " if not self.model_params[\"no_event_token\"]:\n",
+ " self.mag_stats += calc_mag_stats(\n",
+ " mag_pred_reshaped, evt_metadata, self.mag_key\n",
+ " )\n",
+ "\n",
+ " self.loc_stats += calc_loc_stats(\n",
+ " loc_pred_reshaped, evt_metadata, self.pos_offset\n",
+ " )\n",
+ "\n",
+ " generate_true_pred_plot(\n",
+ " mag_pred_reshaped,\n",
+ " evt_metadata[self.mag_key].values,\n",
+ " time,\n",
+ " self.output_dir,\n",
+ " )\n",
+ " self.pga_stats = calc_pga_stats(pga_pred_reshaped, pga_true_reshaped)\n",
+ "\n",
+ " def _save_results(self):\n",
+ " \"\"\"\n",
+ " Save the final results (magnitude, location, and PGA statistics) to a JSON file.\n",
+ " \"\"\"\n",
+ " times = self.cfg[\"model\"].get(\"times\")\n",
+ " self.logger.info(\"times: {}\".format(times))\n",
+ " self.logger.info(\"mag_stats: {}\".format(self.mag_stats))\n",
+ " self.logger.info(\"loc_stats: {}\".format(self.loc_stats))\n",
+ " self.logger.info(\"pga_stats: {}\".format(self.pga_stats))\n",
+ "\n",
+ " def test(self):\n",
+ " \"\"\"\n",
+ " Perform inference for all specified times, process predictions, and save results.\n",
+ " This method iterates over the specified times, performs predictions, processes\n",
+ " the results, and saves the final statistics.\n",
+ " \"\"\"\n",
+ " data_data, evt_key, evt_metadata, meta_data, data_path = load_data(self.cfg)\n",
+ " pga_true = data_data[\"pga\"]\n",
+ " for time in self.cfg[\"model\"].get(\"times\"):\n",
+ " pred = predict_at_time(\n",
+ " self.model,\n",
+ " time,\n",
+ " data_data,\n",
+ " data_path,\n",
+ " evt_key,\n",
+ " evt_metadata,\n",
+ " config=self.cfg,\n",
+ " pga=self.pga,\n",
+ " sampling_rate=meta_data[\"sampling_rate\"],\n",
+ " )\n",
+ " mag_pred, loc_pred, pga_pred = self._parse_predictions(pred)\n",
+ " self._process_predictions(\n",
+ " mag_pred, loc_pred, pga_pred, time, evt_metadata, pga_true\n",
+ " )\n",
+ " self._save_results()\n",
+ " print(\"Inference completed and results saved\")\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "e45e95cf-a0e8-4ce6-a093-88cc4d871578",
+ "metadata": {},
+ "source": [
+ "## 开始推理"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 17,
+ "id": "bdac9a30-d8f9-4f16-850f-def6e23d46a2",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Data loaded from ./dataset/diting2_2020-2022_sc_abridged_test_filter_pga.pkl\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "2025-04-05 08:57:42,398 - forcast.py[line:115] - INFO: times: [5]\n",
+ "2025-04-05 08:57:42,398 - forcast.py[line:115] - INFO: times: [5]\n",
+ "2025-04-05 08:57:42,399 - forcast.py[line:116] - INFO: mag_stats: [-5.849881172180176, 0.26172267853934106, 0.2561628818511963]\n",
+ "2025-04-05 08:57:42,399 - forcast.py[line:116] - INFO: mag_stats: [-5.849881172180176, 0.26172267853934106, 0.2561628818511963]\n",
+ "2025-04-05 08:57:42,400 - forcast.py[line:117] - INFO: loc_stats: [5.55861115185705, 5.1707730693636345, 4.317579930843666, 4.128873124004999]\n",
+ "2025-04-05 08:57:42,400 - forcast.py[line:117] - INFO: loc_stats: [5.55861115185705, 5.1707730693636345, 4.317579930843666, 4.128873124004999]\n",
+ "2025-04-05 08:57:42,402 - forcast.py[line:118] - INFO: pga_stats: [0.8641006385570611, 0.4655571071890895, 0.28675066434439034]\n",
+ "2025-04-05 08:57:42,402 - forcast.py[line:118] - INFO: pga_stats: [0.8641006385570611, 0.4655571071890895, 0.28675066434439034]\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Inference completed and results saved\n"
+ ]
+ }
+ ],
+ "source": [
+ "processor = GTeamInference(model, config, save_dir, logger_obj)\n",
+ "processor.test()"
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3 (ipykernel)",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.11.11"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}
diff --git a/MindEarth/applications/earthquake/G-TEAM/README.md b/MindEarth/applications/earthquake/G-TEAM/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..4ca498841c3069608ae82677c591abf77fcc6c21
--- /dev/null
+++ b/MindEarth/applications/earthquake/G-TEAM/README.md
@@ -0,0 +1,80 @@
+ENGLISH | [简体中文](README_CN.md)
+
+# G-TEAM Earthquake Early Warning Model
+
+## Overview
+
+The earthquake early warning system aims to issue alerts before destructive seismic waves arrive, thereby reducing casualties and economic losses. The G-TEAM model is a data-driven national earthquake early warning system that integrates Graph Neural Networks (GNN) and Transformer architectures. It can rapidly estimate epicenter location, magnitude, and seismic intensity distribution within 3 seconds after earthquake occurrence. By directly processing raw seismic waveform data, the model eliminates limitations from manual feature selection and enhances prediction accuracy and real-time performance through multi-station data utilization.
+
+This model is an efficient earthquake early warning system combining Graph Neural Networks (GNN) and Transformer architectures, taking seismic waveform data from any number of seismic stations as input. It enables real-time processing of seismic signals to deliver fast and precise estimations of hypocenter location, magnitude, and seismic intensity distribution range (characterized by Peak Ground Acceleration, PGA). Leveraging deep learning methods, the model fully exploits spatial correlations and temporal features within seismic networks to improve warning accuracy and response speed, providing robust support for earthquake emergency response and disaster mitigation strategies.
+
+
+
+The PGA prediction architecture using multi-source seismic station data operates as follows:
+
+1. The system receives position data and waveform recordings from multiple seismic stations, along with target coordinates for PGA estimation.
+2. For each station's waveform data:
+ - Perform standardization
+ - Extract features via Convolutional Neural Networks (CNN)
+ - Fuse features through fully connected layers
+ - Combine with station coordinates to form feature vectors
+3. Target PGA coordinates are processed through positional encoding to generate feature vectors.
+4. All feature vectors are sequentially fed into a Transformer encoder that captures global dependencies via self-attention mechanisms.
+5. Encoder outputs pass through three independent fully connected layers to perform regression tasks: magnitude estimation, epicenter localization, and PGA prediction.
+
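+The sketch below illustrates steps 1-5 with MindSpore building blocks. It is illustrative only: the layer sizes, the pooling step, and the use of the first encoder token as the event token are assumptions rather than the exact G-TEAM implementation (the actual layers live in `src/models.py`).
+
+```python
+# Illustrative sketch of the pipeline above, not the exact G-TEAM network.
+import mindspore.nn as nn
+import mindspore.ops as ops
+
+
+class GTeamSketch(nn.Cell):
+    def __init__(self, feat_dim=500, n_heads=10, n_layers=6):
+        super().__init__()
+        # Steps 1-2: per-station waveform encoder (CNN + fully connected fusion).
+        self.conv = nn.Conv1d(3, 32, kernel_size=16)  # 3 waveform channels
+        self.fuse = nn.Dense(32, feat_dim)
+        # Step 3: embed (lat, lon, depth) coordinates of stations and PGA targets.
+        self.coord_embed = nn.Dense(3, feat_dim)
+        # Step 4: Transformer encoder capturing dependencies across all tokens.
+        layer = nn.TransformerEncoderLayer(
+            d_model=feat_dim, nhead=n_heads, dim_feedforward=1000, batch_first=True
+        )
+        self.encoder = nn.TransformerEncoder(layer, num_layers=n_layers)
+        # Step 5: three independent regression heads.
+        self.mag_head = nn.Dense(feat_dim, 1)
+        self.loc_head = nn.Dense(feat_dim, 3)
+        self.pga_head = nn.Dense(feat_dim, 1)
+
+    def construct(self, waveforms, station_coords, target_coords):
+        # waveforms: (batch, stations, channels, time); coords: (batch, n, 3)
+        b, s, c, t = waveforms.shape
+        x = self.conv(waveforms.reshape(b * s, c, t))       # (b*s, 32, t)
+        x = self.fuse(x.mean(axis=-1)).reshape(b, s, -1)    # per-station features
+        station_tokens = x + self.coord_embed(station_coords)
+        target_tokens = self.coord_embed(target_coords)
+        tokens = ops.cat((station_tokens, target_tokens), axis=1)
+        hidden = self.encoder(tokens)
+        event = hidden[:, 0]  # stand-in for a dedicated event token
+        pga = self.pga_head(hidden[:, s:].reshape(-1, hidden.shape[-1])).reshape(b, -1)
+        return self.mag_head(event), self.loc_head(event), pga
+```
+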
+## Training Data
+
+The model is trained using the [Diting Dataset 2.0 - Multifunctional Large AI Training Dataset for China Seismic Network](http://www.esdc.ac.cn/article/137), which contains:
+
+- Waveform records from 1,177 fixed stations in China (15°-50°N, 65°-140°E)
+- Data coverage: March 2020 to February 2023
+- 264,298 local seismic events (M > 0)
+- Only first-arrival P-wave and S-wave phases are used for training
+- Only events recorded by at least three stations are retained, for reliability
+
+The inference module has been open-sourced and supports prediction using provided checkpoint files (.ckpt).
+
+## Quick Start
+
+You can download the required data and ckpt files for training and inference at [dataset](https://download-mindspore.osinfra.cn/mindscience/mindearth/dataset/G-TEAM/)
+
+### Execution
+
+Run via command line using the `main` script:
+
+```bash
+python main.py --cfg_path ./config/GTEAM.yaml --device_id 0 --device_target Ascend
+```
+
+Parameters:
+
+- `--cfg_path`: configuration file path (default: `./config/GTEAM.yaml`)
+- `--device_target`: hardware type (default: `Ascend`)
+- `--device_id`: device ID (default: `0`)
+
+### Visualization
+
+
+
+The scatter plot compares predicted PGA values (x-axis) against observed values (y-axis); the closer the points lie to the y = x line, the more accurate the predictions.
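+
+A minimal sketch of how such a scatter plot can be produced is shown below. The repository's actual plotting is implemented by `generate_true_pred_plot` in `src/visual.py`; the function here is only an assumed, simplified stand-in.
+
+```python
+# Minimal, assumed sketch of a predicted-vs-observed scatter plot
+# (not the repository's generate_true_pred_plot implementation).
+import matplotlib.pyplot as plt
+
+
+def plot_pred_vs_true(y_pred, y_true, out_path="pga_scatter.png"):
+    # y_pred, y_true: 1-D numpy arrays of predicted and observed PGA values
+    lims = [min(y_pred.min(), y_true.min()), max(y_pred.max(), y_true.max())]
+    plt.figure(figsize=(5, 5))
+    plt.scatter(y_pred, y_true, s=4, alpha=0.4)
+    plt.plot(lims, lims, "r--", label="y = x")  # perfect-prediction reference
+    plt.xlabel("Predicted PGA")
+    plt.ylabel("Observed PGA")
+    plt.legend()
+    plt.savefig(out_path, dpi=150)
+    plt.close()
+```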
+
+### Results
+
+| Parameter | NPU |
+|:----------------------:|:--------------------------:|
+| Hardware | Ascend, memory 64G |
+| MindSpore Version | mindspore2.5.0 |
+| Dataset | diting2_2020-2022_sc |
+| Test Parameters | batch_size=1, steps=9 |
+| Magnitude Error (RMSE, MSE) | [ 0.262, 0.257 ] |
+| Epicenter Distance Error (RMSE, MAE) | [ 4.318 , 4.123 ] |
+| Hypocenter Depth Error (RMSE, MAE) | [ 5.559 , 5.171 ] |
+| PGA Error (RMSE, MSE) |[ 0.466, 0.287 ] |
+| Inference Resource | 1 NPU |
+| Inference Speed (ms/step) | 556 |
+
+## Contributors
+
+gitee id: funfunplus
+
+email: funniless@163.com
\ No newline at end of file
diff --git a/MindEarth/applications/earthquake/G-TEAM/README_CN.md b/MindEarth/applications/earthquake/G-TEAM/README_CN.md
new file mode 100644
index 0000000000000000000000000000000000000000..4fc45eae12546615ad6987b7c593bfe9d3949efb
--- /dev/null
+++ b/MindEarth/applications/earthquake/G-TEAM/README_CN.md
@@ -0,0 +1,61 @@
+[ENGLISH](README.md) | 简体中文
+
+# G-TEAM地震预警模型
+
+## 概述
+
+地震预警系统旨在在破坏性震动到达前尽早发出警报,以减少人员伤亡和经济损失。G-TEAM 模型是一种数据驱动的全国地震预警系统,结合了图神经网络(GNN)和 Transformer 架构,能够在地震发生后 3 秒内迅速提供震中位置、震级及地震强度分布。该模型通过直接处理原始地震波形数据,避免了手动特征选择的限制,并充分利用多台站数据,提高了预测的准确性和实时性。
+
+本模型是一款高效的地震预警系统,结合了图神经网络(Graph Neural Network, GNN)与 Transformer 架构,以任意数量的地震台站记录的地震波形数据作为输入。该模型能够实时接收地震信号,并对震源位置、震级以及地震烈度分布范围进行快速且精准的估计,其中烈度分布范围以地面峰值加速度(Peak Ground Acceleration, PGA)表征。通过深度学习方法,本模型可以充分利用地震台网的空间关联性与时序特征,提高预警精度和响应速度,为地震应急响应和减灾决策提供可靠支持。
+
+
+
+该模型采用多源地震台站数据进行PGA预测,具体架构如下:首先,系统接收多个地震台站的位置信息及其记录的地震波形数据,同时获取待估计PGA的目标位置坐标。对于每个地震台站的波形数据,首先进行标准化处理,随后通过卷积神经网络(CNN)进行特征提取。提取的特征经全连接层进行特征融合,并与对应台站的位置信息共同构成特征向量。
+目标PGA位置坐标经过位置编码模块处理后,形成特征向量。所有特征向量按序列形式输入到Transformer编码器中,编码器通过自注意力机制捕捉全局依赖关系。编码器输出依次通过三个独立的全连接层,分别完成地震事件震级、震中位置以及PGA的回归预测任务。
+
+本模型的训练数据来源于[谛听数据集2.0 -中国地震台网多功能大型人工智能训练数据集](http://www.esdc.ac.cn/article/137),该数据集汇集了中国大陆及其邻近地区(15°-50°N,65°-140°E)1177 个中国地震台网固定台站的波形记录,覆盖时间范围为 2020 年 3 月至 2023 年 2 月。数据集包含研究区域内所有震级大于 0 的地方震事件,共计 264,298 个。我们在训练过程中仅选取了初至 P 波和 S 波震相,并且只保留至少被三个台站记录到的地震事件,以确保数据的可靠性和稳定性。
+
+目前本模型已开源推理部分,可使用提供的[ckpt](https://download-mindspore.osinfra.cn/mindscience/mindearth/dataset/G-TEAM/)进行推理。
+
+## 快速开始
+
+可在[dataset](https://download-mindspore.osinfra.cn/mindscience/mindearth/dataset/G-TEAM/)下载训练与推理所需的数据集及ckpt文件。
+
+### 运行方式: 在命令行调用`main`脚本
+
+### 推理
+
+```bash
+python main.py --cfg_path ./config/GTEAM.yaml --device_id 0 --device_target Ascend
+```
+
+其中,`--cfg_path` 表示配置文件路径,默认值 `./config/GTEAM.yaml`;`--device_target` 表示设备类型,默认 `Ascend`;`--device_id` 表示运行设备的编号,默认值 `0`。
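+
+除命令行方式外,也可以参照下面的示例在 Python 中直接调用推理流程(与 `main.py` 中 `test()` 的流程一致,设备与路径取默认配置):
+
+```python
+# 在 Python 中直接运行推理,与 main.py 的流程一致
+import mindspore as ms
+from mindspore import context
+from mindearth import load_yaml_config, make_dir
+
+from src.utils import init_model, get_logger
+from src.forcast import GTeamInference
+
+config = load_yaml_config("./config/GTEAM.yaml")
+context.set_context(mode=ms.PYNATIVE_MODE)
+ms.set_device(device_target="Ascend", device_id=0)
+
+save_dir = config["summary"].get("summary_dir", "./summary")
+make_dir(save_dir)
+model = init_model(config)
+processor = GTeamInference(model, config, save_dir, get_logger(config))
+processor.test()
+```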
+
+### 结果可视化
+
+
+
+图中为 PGA 预测值与实际值的散点图:横轴表示预测值,纵轴表示实际值,数据点越靠近 y=x 直线,预测越准确。
+
+### 结果展示
+
+| 参数 | NPU |
+|:----------------------:|:--------------------------:|
+| 硬件 | Ascend, memory 64G |
+| mindspore版本 | mindspore2.5.0 |
+| 数据集 | diting2_2020-2022_sc |
+| 测试参数 | batch_size=1, steps=9 |
+| Mag震级误差(RMSE, MSE) | [ 0.262, 0.257 ] |
+| Loc震中距离误差(RMSE, MAE) | [ 4.318 , 4.123 ] |
+| Loc震源距离误差(RMSE, MAE) | [ 5.559 , 5.171 ] |
+| Pga峰值地面加速度误差(RMSE, MSE) |[ 0.466, 0.287 ] |
+| 推理资源 | 1NPU |
+| 推理速度(ms/step) | 556 |
+
+## 贡献者
+
+gitee id: funfunplus
+
+email: funniless@163.com
\ No newline at end of file
diff --git a/MindEarth/applications/earthquake/G-TEAM/config/GTEAM.yaml b/MindEarth/applications/earthquake/G-TEAM/config/GTEAM.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..68c066929f8c5503703691f7f9494f86acc29fa6
--- /dev/null
+++ b/MindEarth/applications/earthquake/G-TEAM/config/GTEAM.yaml
@@ -0,0 +1,40 @@
+model:
+ hidden_dim: 1000
+ hidden_dropout: 0.0
+ n_heads: 10
+ n_pga_targets: 15
+ output_location_dims: [150,100,50,30,3]
+ output_mlp_dims: [150,100,50,30,1]
+ transformer_layers: 6
+ waveform_model_dims: [500,500,500]
+ wavelength: [[0.01,15],[0.01,15],[0.01,10]]
+ times: [5]
+ run_with_less_data: false
+ pga: true
+ mode: test
+  no_event_token: False
+data:
+ root_dir: "./dataset"
+ batch_size: 64
+ max_stations: 5
+ disable_station_foreshadowing: true
+ key: Mag
+ magnitude_resampling: 1
+ min_mag: None
+ min_upsample_magnitude: 4
+ aug_large: True
+ pga_from_inactive: true
+ pga_key: pga
+ pga_selection_skew: 1000
+ pos_offset: [30,102]
+ scale_metadata: false
+ selection_skew: 1000
+ shuffle_train_dev: true
+ transform_target_only: false
+ trigger_based: true
+ waveform_shape: [3000, 3]
+ overwrite_sampling_rate: None
+ noise_seconds: 5
+summary:
+ summary_dir: "./summary"
+ ckpt_path: "./dataset/ckpt/g_team.ckpt"
diff --git a/MindEarth/applications/earthquake/G-TEAM/images/image.png b/MindEarth/applications/earthquake/G-TEAM/images/image.png
new file mode 100644
index 0000000000000000000000000000000000000000..7c638885b15de261c21b664bd576b55ec169dd99
Binary files /dev/null and b/MindEarth/applications/earthquake/G-TEAM/images/image.png differ
diff --git a/MindEarth/applications/earthquake/G-TEAM/images/pga.png b/MindEarth/applications/earthquake/G-TEAM/images/pga.png
new file mode 100644
index 0000000000000000000000000000000000000000..b9b8fd6b2b9bd6ae1cc363f6ca4f5310287fe2b0
Binary files /dev/null and b/MindEarth/applications/earthquake/G-TEAM/images/pga.png differ
diff --git a/MindEarth/applications/earthquake/G-TEAM/main.py b/MindEarth/applications/earthquake/G-TEAM/main.py
new file mode 100644
index 0000000000000000000000000000000000000000..d8237f8eaa395f327859e53cb21566320c7a8ad6
--- /dev/null
+++ b/MindEarth/applications/earthquake/G-TEAM/main.py
@@ -0,0 +1,51 @@
+# Copyright 2023 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"main function"
+import argparse
+
+import mindspore as ms
+from mindspore import context
+from mindearth import load_yaml_config, make_dir
+
+from src.utils import init_model, get_logger
+from src.forcast import GTeamInference
+
+
+def get_args():
+ """get args"""
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--cfg_path", default="./config/GTEAM.yaml", type=str)
+ parser.add_argument("--device_id", default=0, type=int)
+ parser.add_argument("--device_target", default="Ascend", type=str)
+ parse_args = parser.parse_args()
+ return parse_args
+
+
+def test(cfg):
+ """main test"""
+ save_dir = cfg["summary"].get("summary_dir", "./summary")
+ make_dir(save_dir)
+ model = init_model(cfg)
+ logger_obj = get_logger(cfg)
+ processor = GTeamInference(model, cfg, save_dir, logger_obj)
+ processor.test()
+
+
+if __name__ == "__main__":
+ args = get_args()
+ config = load_yaml_config(args.cfg_path)
+ context.set_context(mode=ms.PYNATIVE_MODE)
+ ms.set_device(device_target=args.device_target, device_id=args.device_id)
+ test(config)
diff --git a/MindEarth/applications/earthquake/G-TEAM/src/data.py b/MindEarth/applications/earthquake/G-TEAM/src/data.py
new file mode 100644
index 0000000000000000000000000000000000000000..80ee3066718e010224d3117064464defccb3dd46
--- /dev/null
+++ b/MindEarth/applications/earthquake/G-TEAM/src/data.py
@@ -0,0 +1,799 @@
+# Copyright 2023 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"load diting data"
+import os
+import pickle
+import glob
+import h5py
+import numpy as np
+
+import mindspore
+from mindspore.dataset import Dataset
+
+# degrees to kilometers
+D2KM = 111.19492664455874
+
+
+def load_pickle_data(filename):
+ """Load and deserialize data from a pickle file."""
+ with open(filename, "rb") as file:
+ data = pickle.load(file)
+ print(f"Data loaded from {filename}")
+ return data
+
+
+def save_pickle_data(filename, data):
+ """Serialize and save data to a pickle file."""
+ with open(filename, "wb") as file:
+ pickle.dump(data, file)
+ print(f"Data saved to {filename}")
+
+
+def load_data(cfg):
+ """Load preprocessed seismic data from a configured pickle file."""
+ data_path = glob.glob(os.path.join(cfg["data"].get("root_dir"), "*.hdf5"))[0]
+ file_basename = os.path.basename(data_path).split(".")[0]
+ filename = os.path.join(
+ cfg["data"].get("root_dir"), f"{file_basename}_test_filter_pga.pkl"
+ )
+ loaded_pickle_data = load_pickle_data(filename)
+ _, evt_metadata, meta_data, data_data, evt_key, _ = loaded_pickle_data
+ return data_data, evt_key, evt_metadata, meta_data, data_path
+
+
+def detect_location_keys(columns):
+ """Identify standardized location keys from column headers."""
+ candidates = [
+ ["LAT", "Latitude(°)", "Latitude"],
+ ["LON", "Longitude(°)", "Longitude"],
+ ["DEPTH", "JMA_Depth(km)", "Depth(km)", "Depth/Km"],
+ ]
+
+ coord_keys = []
+ for keyset in candidates:
+ for key in keyset:
+ if key in columns:
+ coord_keys += [key]
+ break
+
+ if len(coord_keys) != len(candidates):
+ raise ValueError("Unknown location key format")
+
+ return coord_keys
+
+
+class EarthquakeDataset(Dataset):
+ """
+ Dataset class for loading and processing seismic event data.
+ Handles waveform loading, magnitude-based resampling, PGA target processing,
+ and batch preparation for earthquake analysis models.
+ Key Features:
+ Batch processing of seismic waveforms and metadata
+ Magnitude-based data resampling for class balance
+ PGA (Peak Ground Acceleration) target handling
+ HDF5 waveform data loading
+ Flexible data shuffling and oversampling
+ """
+
+ def __init__(
+ self,
+ data_path,
+ event_key,
+ data,
+ event_metadata,
+ batch_size=32,
+ shuffle=True,
+ oversample=1,
+ magnitude_resampling=3,
+ min_upsample_magnitude=2,
+ pga_targets=None,
+ pga_mode=False,
+ pga_key="pga",
+ coord_keys=None,
+ **kwargs,
+ ):
+
+ super(EarthquakeDataset, self).__init__()
+
+ self.data_path = data_path
+ self.event_key = event_key
+ self.batch_size = batch_size
+ self.shuffle = shuffle
+ self.metadata = data["coords"]
+ self.event_metadata = event_metadata
+ self.pga = data[pga_key]
+ self.triggers = data["p_picks"]
+ self.oversample = oversample
+
+ self.pga_mode = pga_mode
+ self.pga_targets = pga_targets
+
+ self.base_indexes = np.arange(self.event_metadata.shape[0])
+ self.reverse_index = None
+
+ if magnitude_resampling > 1:
+ magnitude = self.event_metadata[kwargs["key"]].values
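+            # Oversample rarer large events: indexes of events in magnitude bin
+            # (i, i + 1] are repeated, the repetition factor growing with magnitude.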
+ for i in np.arange(min_upsample_magnitude, 9):
+ ind = np.where(np.logical_and(i < magnitude, magnitude <= i + 1))[0]
+ self.base_indexes = np.concatenate(
+ (
+ self.base_indexes,
+ np.repeat(ind, int(magnitude_resampling ** (i - 1) - 1)),
+ )
+ )
+
+ if pga_mode and pga_targets is not None:
+ new_base_indexes = []
+ self.reverse_index = []
+ c = 0
+ for idx in self.base_indexes:
+ num_samples = (len(self.pga[idx]) - 1) // pga_targets + 1
+ new_base_indexes += [(idx, i) for i in range(num_samples)]
+ self.reverse_index += [c]
+ c += num_samples
+ self.reverse_index += [c]
+ self.base_indexes = new_base_indexes
+ if coord_keys is None:
+ self.coord_keys = detect_location_keys(event_metadata.columns)
+ else:
+ self.coord_keys = coord_keys
+ self.use_shuffle()
+
+ def __len__(self):
+ """get length"""
+ return int(np.ceil(len(self.indexes) / self.batch_size))
+
+ def __getitem__(self, index):
+ """Load data."""
+ batch_indexes = self.indexes[
+ index * self.batch_size : (index + 1) * self.batch_size
+ ]
+ batch_data = {
+ "indexes": batch_indexes,
+ "waveforms": [],
+ "metadata": [],
+ "pga": [],
+ "p_picks": [],
+ "event_info": [],
+ }
+ if self.pga_mode:
+ batch_data["pga_indexes"] = [x[1] for x in batch_indexes]
+ batch_data["indexes"] = [x[0] for x in batch_indexes]
+ for idx in batch_data["indexes"]:
+ event = self.event_metadata.iloc[idx]
+ event_name = str(event[self.event_key])
+ waveform_data = self._load_waveform_data(event_name)
+ batch_data["waveforms"].append(waveform_data)
+ batch_data["metadata"].append(self.metadata[idx])
+ batch_data["pga"].append(self.pga[idx])
+ batch_data["p_picks"].append(self.triggers[idx])
+ batch_data["event_info"].append(event)
+
+ return batch_data
+
+ def _load_waveform_data(self, event_name):
+ """load waveform data"""
+ with h5py.File(self.data_path, "r") as f:
+ if "data" not in f or event_name not in f["data"]:
+ return None
+ g_event = f["data"][event_name]
+ if "waveforms" not in g_event:
+ return None
+ return g_event["waveforms"][:, :, :]
+
+ def use_shuffle(self):
+ """shuffle index"""
+ self.indexes = np.repeat(self.base_indexes.copy(), self.oversample, axis=0)
+ if self.shuffle:
+ np.random.shuffle(self.indexes)
+
+
+class DataProcessor:
+ """
+ A data processor for seismic event analysis that handles waveform preprocessing,
+ station selection, and target preparation for machine learning models.
+ Key functionalities:
+ Batch processing of seismic waveforms and metadata
+ Station selection strategies for efficient processing
+ Multiple preprocessing techniques (cutout, integration, etc.)
+ Coordinate transformations and target preparations
+ PGA (Peak Ground Acceleration) target handling
+ Data augmentation techniques (label smoothing, station blinding)
+ """
+
+ def __init__(
+ self,
+ waveform_shape=(3000, 6),
+ max_stations=None,
+ cutout=None,
+ sliding_window=False,
+ windowlen=3000,
+ coords_target=True,
+ pos_offset=(-21, -69),
+ label_smoothing=False,
+ station_blinding=False,
+ pga_targets=None,
+ adjust_mean=True,
+ transform_target_only=False,
+ trigger_based=None,
+ disable_station_foreshadowing=False,
+ selection_skew=None,
+ pga_from_inactive=False,
+ integrate=False,
+ sampling_rate=100.0,
+ select_first=False,
+ scale_metadata=True,
+ p_pick_limit=5000,
+ pga_mode=False,
+ no_event_token=False,
+ pga_selection_skew=None,
+ **kwargs,
+ ):
+ self.waveform_shape = waveform_shape
+ self.max_stations = max_stations
+ self.cutout = cutout
+ self.sliding_window = sliding_window
+ self.windowlen = windowlen
+ self.coords_target = coords_target
+ self.pos_offset = pos_offset
+ self.label_smoothing = label_smoothing
+ self.station_blinding = station_blinding
+ self.pga_targets = pga_targets
+ self.adjust_mean = adjust_mean
+ self.transform_target_only = transform_target_only
+ self.trigger_based = trigger_based
+ self.disable_station_foreshadowing = disable_station_foreshadowing
+ self.selection_skew = selection_skew
+ self.pga_from_inactive = pga_from_inactive
+ self.integrate = integrate
+ self.sampling_rate = sampling_rate
+ self.select_first = select_first
+ self.scale_metadata = scale_metadata
+ self.p_pick_limit = p_pick_limit
+ self.pga_mode = pga_mode
+ self.no_event_token = no_event_token
+ self.pga_selection_skew = pga_selection_skew
+ self.key = kwargs["key"]
+
+ def process_batch(self, batch_data):
+ """Main method to process a batch of data, now decomposed into smaller functions."""
+ (
+ indexes,
+ waveforms_list,
+ metadata_list,
+ pga_list,
+ p_picks_list,
+ event_info_list,
+ pga_indexes,
+ ) = self._extract_batch_data(batch_data)
+
+ true_batch_size = len(indexes)
+ true_max_stations_in_batch = self._get_max_stations_in_batch(metadata_list)
+ waveforms, metadata, pga, full_p_picks, p_picks, reverse_selections = (
+ self._initialize_arrays(
+ true_batch_size, true_max_stations_in_batch, metadata_list
+ )
+ )
+ waveforms, metadata, pga, p_picks, full_p_picks, reverse_selections = (
+ self._process_stations(
+ waveforms_list,
+ metadata_list,
+ pga_list,
+ p_picks_list,
+ waveforms,
+ metadata,
+ pga,
+ p_picks,
+ full_p_picks,
+ )
+ )
+ magnitude, target = self._process_magnitude_and_targets(event_info_list)
+ org_waveform_length = waveforms.shape[2]
+ waveforms, _ = self._process_waveforms(waveforms, org_waveform_length, p_picks)
+ metadata, target = self._transform_locations(metadata, target)
+ magnitude = self._apply_label_smoothing(magnitude)
+ metadata, pga = self._adjust_metadata_and_pga(metadata, pga)
+ pga_values, pga_targets_data = self._process_pga_targets(
+ true_batch_size,
+ pga,
+ metadata,
+ pga_indexes,
+ reverse_selections,
+ full_p_picks,
+ indexes,
+ )
+ waveforms, metadata = self._apply_station_blinding(waveforms, metadata)
+ waveforms, metadata = self._handle_stations_without_trigger(waveforms, metadata)
+ waveforms, metadata = self._ensure_no_empty_arrays(waveforms, metadata)
+ inputs, outputs = self._prepare_model_io(
+ waveforms, metadata, magnitude, target, pga_targets_data, pga_values
+ )
+
+ return inputs, outputs
+
+ def _extract_batch_data(self, batch_data):
+ """Extract data from the batch dictionary."""
+ indexes = batch_data["indexes"]
+ waveforms_list = batch_data["waveforms"]
+ metadata_list = batch_data["metadata"]
+ pga_list = batch_data["pga"]
+ p_picks_list = batch_data["p_picks"]
+ event_info_list = batch_data["event_info"]
+ pga_indexes = batch_data.get("pga_indexes", None)
+
+ return (
+ indexes,
+ waveforms_list,
+ metadata_list,
+ pga_list,
+ p_picks_list,
+ event_info_list,
+ pga_indexes,
+ )
+
+ def _get_max_stations_in_batch(self, metadata_list):
+ """Calculate the maximum number of stations in the batch."""
+ return max(
+ [len(m) for m in metadata_list if m is not None] + [self.max_stations]
+ )
+
+ def _initialize_arrays(
+ self, true_batch_size, true_max_stations_in_batch, metadata_list
+ ):
+ """Initialize arrays for batch processing."""
+ waveforms = np.zeros([true_batch_size, self.max_stations] + self.waveform_shape)
+ metadata = np.zeros(
+ (true_batch_size, true_max_stations_in_batch) + metadata_list[0].shape[1:]
+ )
+ pga = np.zeros((true_batch_size, true_max_stations_in_batch))
+ full_p_picks = np.zeros((true_batch_size, true_max_stations_in_batch))
+ p_picks = np.zeros((true_batch_size, self.max_stations))
+ reverse_selections = []
+
+ return waveforms, metadata, pga, full_p_picks, p_picks, reverse_selections
+
+ def _process_stations(
+ self,
+ waveforms_list,
+ metadata_list,
+ pga_list,
+ p_picks_list,
+ waveforms,
+ metadata,
+ pga,
+ p_picks,
+ full_p_picks,
+ ):
+ """Process stations and waveforms for each item in the batch."""
+ reverse_selections = []
+
+ for i, (waveform_data, meta, pga_data, p_pick_data) in enumerate(
+ zip(waveforms_list, metadata_list, pga_list, p_picks_list)
+ ):
+ if waveform_data is None:
+ continue
+
+ num_stations = waveform_data.shape[0]
+
+ if num_stations <= self.max_stations:
+ waveforms[i, :num_stations] = waveform_data
+ metadata[i, : len(meta)] = meta
+ pga[i, : len(pga_data)] = pga_data
+ p_picks[i, : len(p_pick_data)] = p_pick_data
+ reverse_selections += [[]]
+ else:
+ selection = self._select_stations(num_stations, p_pick_data)
+
+ metadata[i, : len(selection)] = meta[selection]
+ pga[i, : len(selection)] = pga_data[selection]
+ full_p_picks[i, : len(selection)] = p_pick_data[selection]
+
+ tmp_reverse_selection = [0 for _ in selection]
+ for j, s in enumerate(selection):
+ tmp_reverse_selection[s] = j
+ reverse_selections += [tmp_reverse_selection]
+
+ selection = selection[: self.max_stations]
+ waveforms[i] = waveform_data[selection]
+ p_picks[i] = p_pick_data[selection]
+
+ return waveforms, metadata, pga, p_picks, full_p_picks, reverse_selections
+
+ def _select_stations(self, num_stations, p_pick_data):
+ """Select stations based on configured strategy."""
+ if self.selection_skew is None:
+ selection = np.arange(0, num_stations)
+ np.random.shuffle(selection)
+ else:
+ tmp_p_picks = p_pick_data.copy()
+            mask = np.logical_or(tmp_p_picks <= 0, tmp_p_picks > self.p_pick_limit)
+ tmp_p_picks[mask] = min(np.max(tmp_p_picks), self.p_pick_limit)
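+            # Weight stations by exp(-pick_time / selection_skew) with random jitter,
+            # biasing the selection towards stations that triggered earliest.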
+ coeffs = np.exp(-tmp_p_picks / self.selection_skew)
+ coeffs *= np.random.random(coeffs.shape)
+ coeffs[p_pick_data == 0] = 0
+ coeffs[p_pick_data > self.waveform_shape[0]] = 0
+ selection = np.argsort(-coeffs)
+
+ if self.select_first:
+ selection = np.argsort(p_pick_data)
+
+ return selection
+
+ def _process_magnitude_and_targets(self, event_info_list):
+ """Process magnitude and coordinate targets."""
+ magnitude = np.array([e[self.key] for e in event_info_list], dtype=np.float32)
+ target = None
+
+ if self.coords_target:
+ coord_keys = detect_location_keys(
+ [col for e in event_info_list for col in e.index]
+ )
+ target = np.array(
+ [[e[k] for k in coord_keys] for e in event_info_list], dtype=np.float32
+ )
+
+ magnitude = np.expand_dims(np.expand_dims(magnitude, axis=-1), axis=-1)
+ return magnitude, target
+
+ def _process_waveforms(self, waveforms, org_waveform_length, p_picks):
+ """Apply cutout, sliding window, trigger-based, and integration transformations to waveforms."""
+ cutout = org_waveform_length
+
+ if self.cutout:
+ if self.sliding_window:
+ windowlen = self.windowlen
+ window_end = np.random.randint(
+ max(windowlen, self.cutout[0]),
+ min(waveforms.shape[2], self.cutout[1]) + 1,
+ )
+ waveforms = waveforms[:, :, window_end - windowlen : window_end]
+
+ cutout = window_end
+ if self.adjust_mean:
+ waveforms -= np.mean(waveforms, axis=2, keepdims=True)
+ else:
+ cutout = np.random.randint(*self.cutout)
+ if self.adjust_mean:
+ waveforms -= np.mean(
+ waveforms[:, :, : cutout + 1], axis=2, keepdims=True
+ )
+ waveforms[:, :, cutout:] = 0
+
+ if self.trigger_based:
+ p_picks[p_picks <= 0] = org_waveform_length
+ waveforms[cutout < p_picks, :, :] = 0
+
+ if self.integrate:
+ waveforms = np.cumsum(waveforms, axis=2) / self.sampling_rate
+
+ return waveforms, cutout
+
+ def _transform_locations(self, metadata, target):
+ """Transform locations using the location_transformation method."""
+ if self.coords_target:
+ metadata, target = self.location_transformation(metadata, target)
+ else:
+ metadata = self.location_transformation(metadata)
+ return metadata, target
+
+ def _apply_label_smoothing(self, magnitude):
+ """Apply label smoothing to magnitude if enabled."""
+ if self.label_smoothing:
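+            # Perturb only labels above magnitude 4 with Gaussian noise whose scale
+            # grows as 0.05 * (magnitude - 4).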
+ magnitude += (
+ (magnitude > 4)
+ * np.random.randn(magnitude.shape[0]).reshape(magnitude.shape)
+ * (magnitude - 4)
+ * 0.05
+ )
+ return magnitude
+
+ def _adjust_metadata_and_pga(self, metadata, pga):
+ """Adjust metadata and PGA arrays based on configuration."""
+ if not self.pga_from_inactive and not self.pga_mode:
+ metadata = metadata[:, : self.max_stations]
+ pga = pga[:, : self.max_stations]
+ return metadata, pga
+
+ def _process_pga_targets(
+ self,
+ true_batch_size,
+ pga,
+ metadata,
+ pga_indexes,
+ reverse_selections,
+ full_p_picks,
+ indexes,
+ ):
+ """Process PGA targets if enabled."""
+ pga_values = None
+ pga_targets_data = None
+
+ if self.pga_targets:
+ pga_values = np.zeros((true_batch_size, self.pga_targets))
+ pga_targets_data = np.zeros((true_batch_size, self.pga_targets, 3))
+
+ if self.pga_mode and pga_indexes is not None:
+ self._process_pga_mode(
+ pga_values,
+ pga_targets_data,
+ pga,
+ metadata,
+ pga_indexes,
+ reverse_selections,
+ )
+ else:
+ self._process_pga_normal(
+ pga_values, pga_targets_data, pga, metadata, full_p_picks, indexes
+ )
+
+ pga_values = pga_values.reshape((true_batch_size, self.pga_targets, 1, 1))
+
+ return pga_values, pga_targets_data
+
+ def _process_pga_mode(
+ self,
+ pga_values,
+ pga_targets_data,
+ pga,
+ metadata,
+ pga_indexes,
+ reverse_selections,
+ ):
+ """Process PGA in PGA mode."""
+ for i in range(len(pga_values)):
+ pga_index = pga_indexes[i]
+ if reverse_selections[i]:
+ sorted_pga = pga[i, reverse_selections[i]]
+ sorted_metadata = metadata[i, reverse_selections[i]]
+ else:
+ sorted_pga = pga[i]
+ sorted_metadata = metadata[i]
+ pga_values_pre = sorted_pga[
+ pga_index * self.pga_targets : (pga_index + 1) * self.pga_targets
+ ]
+ pga_values[i, : len(pga_values_pre)] = pga_values_pre
+ pga_targets_pre = sorted_metadata[
+ pga_index * self.pga_targets : (pga_index + 1) * self.pga_targets,
+ :,
+ ]
+ if pga_targets_pre.shape[-1] == 4:
+ pga_targets_pre = pga_targets_pre[:, (0, 1, 3)]
+ pga_targets_data[i, : len(pga_targets_pre), :] = pga_targets_pre
+
+ def _process_pga_normal(
+ self, pga_values, pga_targets_data, pga, metadata, full_p_picks, indexes
+ ):
+ """Process PGA in normal mode."""
+ pga[np.logical_or(np.isnan(pga), np.isinf(pga))] = 0
+ for i in range(pga_values.shape[0]):
+ active = np.where(pga[i] != 0)[0]
+            if len(active) == 0:
+ raise ValueError(f"Found event without PGA idx={indexes[i]}")
+ while len(active) < self.pga_targets:
+ active = np.repeat(active, 2)
+
+ if self.pga_selection_skew is not None:
+ active = self._select_pga_with_skew(active, full_p_picks[i])
+ else:
+ np.random.shuffle(active)
+
+ samples = active[: self.pga_targets]
+ if metadata.shape[-1] == 3:
+ pga_targets_data[i] = metadata[i, samples, :]
+ else:
+ full_targets = metadata[i, samples]
+ pga_targets_data[i] = full_targets[:, (0, 1, 3)]
+ pga_values[i] = pga[i, samples]
+
+ def _select_pga_with_skew(self, active, full_p_picks):
+ """Select PGA with skew-based selection."""
+ active_p_picks = full_p_picks[active]
+        mask = np.logical_or(active_p_picks <= 0, active_p_picks > self.p_pick_limit)
+ active_p_picks[mask] = min(np.max(active_p_picks), self.p_pick_limit)
+ coeffs = np.exp(-active_p_picks / self.pga_selection_skew)
+ coeffs *= np.random.random(coeffs.shape)
+ return active[np.argsort(-coeffs)]
+
+ def _apply_station_blinding(self, waveforms, metadata):
+ """Apply station blinding if enabled."""
+ if self.station_blinding:
+ mask = np.zeros(waveforms.shape[:2], dtype=bool)
+
+ for i in range(waveforms.shape[0]):
+ active = np.where((waveforms[i] != 0).any(axis=(1, 2)))[0]
+                if len(active) == 0:
+ active = np.zeros(1, dtype=int)
+ blind_length = np.random.randint(0, len(active))
+ np.random.shuffle(active)
+ blind = active[:blind_length]
+ mask[i, blind] = True
+
+ waveforms[mask] = 0
+ metadata[mask] = 0
+
+ return waveforms, metadata
+
+ def _handle_stations_without_trigger(self, waveforms, metadata):
+ """Handle stations without trigger signal."""
+ stations_without_trigger = (metadata != 0).any(axis=2) & (waveforms == 0).all(
+ axis=(2, 3)
+ )
+
+ if self.disable_station_foreshadowing:
+ metadata[stations_without_trigger] = 0
+ else:
+ waveforms[stations_without_trigger, 0, 0] += 1e-9
+
+ return waveforms, metadata
+
+ def _ensure_no_empty_arrays(self, waveforms, metadata):
+ """Ensure there are no empty arrays in the batch."""
+ mask = np.logical_and(
+ (metadata == 0).all(axis=(1, 2)), (waveforms == 0).all(axis=(1, 2, 3))
+ )
+ waveforms[mask, 0, 0, 0] = 1e-9
+ metadata[mask, 0, 0] = 1e-9
+
+ return waveforms, metadata
+
+ def _prepare_model_io(
+ self, waveforms, metadata, magnitude, target, pga_targets_data, pga_values
+ ):
+ """Prepare model inputs and outputs."""
+ inputs = [
+ mindspore.tensor(waveforms, dtype=mindspore.float32),
+ mindspore.tensor(metadata, dtype=mindspore.float32),
+ ]
+ outputs = []
+
+ if not self.no_event_token:
+ outputs += [mindspore.tensor(magnitude, dtype=mindspore.float32)]
+
+ if self.coords_target:
+ target = np.expand_dims(target, axis=-1)
+ outputs += [mindspore.tensor(target, dtype=mindspore.float32)]
+
+ if self.pga_targets and pga_values is not None and pga_targets_data is not None:
+ inputs += [mindspore.tensor(pga_targets_data, dtype=mindspore.float32)]
+ outputs += [mindspore.tensor(pga_values, dtype=mindspore.float32)]
+
+ return inputs, outputs
+
+ def location_transformation(self, metadata, target=None):
+ """
+ Apply transformations to the metadata and optionally to the target.
+ Adjusts positions based on a positional offset and scales the data if required.
+ """
+ transform_target_only = self.transform_target_only
+ metadata = metadata.copy()
+
+ metadata_old = metadata
+ metadata = metadata.copy()
+ mask = (metadata == 0).all(axis=2)
+
+ if target is not None:
+ target[:, 0] -= self.pos_offset[0]
+ target[:, 1] -= self.pos_offset[1]
+
+ metadata[:, :, 0] -= self.pos_offset[0]
+ metadata[:, :, 1] -= self.pos_offset[1]
+ if self.scale_metadata:
+ metadata[:, :, :2] *= D2KM
+ if target is not None:
+ target[:, :2] *= D2KM
+ metadata[mask] = 0
+ if self.scale_metadata:
+ metadata /= 100
+ if target is not None:
+ target /= 100
+ if transform_target_only:
+ metadata = metadata_old
+ if target is None:
+ return metadata
+ return metadata, target
+
+
+class PreloadedEventGenerator(Dataset):
+ """
+    A custom MindSpore dataset class designed to generate preloaded event data for training or evaluation.
+ This class wraps an `EarthquakeDataset` and a `DataProcessor` to provide processed input-output pairs.
+ Attributes:
+ dataset (EarthquakeDataset): An instance of the EarthquakeDataset class, responsible for loading
+ raw earthquake-related data.
+ processor (DataProcessor): An instance of the DataProcessor class, responsible for processing
+ the raw data into model-ready inputs and outputs.
+ """
+
+ def __init__(self, data_path, event_key, data, event_metadata, **kwargs):
+ """
+ Initializes the PreloadedEventGenerator.
+ Args:
+ data_path (str): The file path or directory where the dataset is stored.
+ event_key (str): A key used to identify specific events within the dataset.
+ data (dict or array-like): Raw data associated with the events.
+ event_metadata (dict or DataFrame): Metadata describing the events in the dataset.
+ **kwargs: Additional keyword arguments passed to both EarthquakeDataset and DataProcessor.
+ """
+ super(PreloadedEventGenerator, self).__init__()
+ self.dataset = EarthquakeDataset(
+ data_path=data_path,
+ event_key=event_key,
+ data=data,
+ event_metadata=event_metadata,
+ **kwargs,
+ )
+ self.processor = DataProcessor(**kwargs)
+
+ def __len__(self):
+ """
+ Returns the total number of samples in the dataset.
+
+ Returns:
+ int: The length of the underlying EarthquakeDataset.
+ """
+ return len(self.dataset)
+
+ def __getitem__(self, index):
+ """
+ Retrieves and processes a single batch of data at the given index.
+
+ Args:
+ index (int): The index of the data sample to retrieve.
+
+ Returns:
+ tuple: A tuple containing two elements:
+ inputs: Processed input data ready for model consumption.
+ outputs: Corresponding target outputs for the model.
+ """
+ batch_data = self.dataset[index]
+ inputs, outputs = self.processor.process_batch(batch_data)
+
+ return inputs, outputs
+
+
+def generator_from_config(
+ config,
+ data,
+ data_path,
+ event_key,
+ event_metadata,
+ time,
+ sampling_rate=100,
+ pga=False,
+):
+ """init generator"""
+ generator_params = config["data"]
+ cutout = int(sampling_rate * (generator_params["noise_seconds"] + time))
+ cutout = (cutout, cutout + 1)
+
+ n_pga_targets = config["model"].get("n_pga_targets", 0)
+ if "data_path" in generator_params:
+ del generator_params["data_path"]
+
+ generator = PreloadedEventGenerator(
+ data_path=data_path,
+ event_key=event_key,
+ data=data,
+ event_metadata=event_metadata,
+ coords_target=True,
+ cutout=cutout,
+ pga_targets=n_pga_targets,
+ sampling_rate=sampling_rate,
+ select_first=True,
+ shuffle=False,
+ pga_mode=pga,
+ **generator_params,
+ )
+ return generator
diff --git a/MindEarth/applications/earthquake/G-TEAM/src/forcast.py b/MindEarth/applications/earthquake/G-TEAM/src/forcast.py
new file mode 100644
index 0000000000000000000000000000000000000000..2c67ab33b3ac1d27f13c9dd34e15175c90ef6dd0
--- /dev/null
+++ b/MindEarth/applications/earthquake/G-TEAM/src/forcast.py
@@ -0,0 +1,145 @@
+# Copyright 2023 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"GTeam inference"
+import numpy as np
+
+from src.utils import (
+ predict_at_time,
+ calc_mag_stats,
+ calc_loc_stats,
+ calc_pga_stats,
+)
+from src.data import load_data
+from src.visual import generate_true_pred_plot
+
+
+class GTeamInference:
+ """
+    G-TEAM inference pipeline: runs predictions at the configured times and computes evaluation statistics.
+ """
+
+ def __init__(self, model_ins, cfg, output_dir, logger):
+ """
+ Args:
+ model_ins: The model instance used for inference.
+ cfg: Configuration dictionary containing model and data parameters.
+            output_dir: Directory to save the output results.
+            logger: Logger instance used to record inference results.
+ Attributes:
+ model: The model instance for inference.
+ cfg: Configuration dictionary.
+ output_dir: Directory to save outputs.
+ pga: Flag indicating if PGA (Peak Ground Acceleration) is enabled.
+ generator_params: Parameters for data generation.
+ model_params: Parameters specific to the model.
+ mag_key: Key for magnitude-related data.
+ pos_offset: Position offset for location predictions.
+ mag_stats: List to store magnitude prediction statistics.
+ loc_stats: List to store location prediction statistics.
+ pga_stats: List to store PGA prediction statistics.
+ """
+ self.model = model_ins
+ self.cfg = cfg
+ self.output_dir = output_dir
+ self.logger = logger
+        self.pga = cfg["model"].get("pga", True)
+ self.generator_params = cfg["data"]
+ self.model_params = cfg["model"]
+ self.mag_key = self.generator_params["key"]
+ self.pos_offset = self.generator_params["pos_offset"]
+ self.mag_stats = []
+ self.loc_stats = []
+ self.pga_stats = []
+
+ def _parse_predictions(self, pred):
+ """
+ Parse the raw predictions into magnitude, location, and PGA components.
+ """
+ mag_pred = pred[0]
+ loc_pred = pred[1]
+ pga_pred = pred[2] if self.pga else []
+ return mag_pred, loc_pred, pga_pred
+
+ def _process_predictions(
+ self, mag_pred, loc_pred, pga_pred, time, evt_metadata, pga_true
+ ):
+ """
+ Process the parsed predictions to compute statistics and generate plots.
+ """
+ mag_pred_np = [t[0].asnumpy() for t in mag_pred]
+ mag_pred_reshaped = np.concatenate(mag_pred_np, axis=0)
+
+ loc_pred_np = [t[0].asnumpy() for t in loc_pred]
+ loc_pred_reshaped = np.array(loc_pred_np)
+
+ pga_pred_np = [t.asnumpy() for t in pga_pred]
+ pga_pred_reshaped = np.concatenate(pga_pred_np, axis=0)
+ pga_true_reshaped = np.log(
+ np.abs(np.concatenate(pga_true, axis=0).reshape(-1, 1))
+ )
+
+ if not self.model_params["no_event_token"]:
+ self.mag_stats += calc_mag_stats(
+ mag_pred_reshaped, evt_metadata, self.mag_key
+ )
+
+ self.loc_stats += calc_loc_stats(
+ loc_pred_reshaped, evt_metadata, self.pos_offset
+ )
+
+ generate_true_pred_plot(
+ mag_pred_reshaped,
+ evt_metadata[self.mag_key].values,
+ time,
+ self.output_dir,
+ )
+ self.pga_stats = calc_pga_stats(pga_pred_reshaped, pga_true_reshaped)
+
+ def _save_results(self):
+ """
+        Log the final results (magnitude, location, and PGA statistics).
+ """
+ times = self.cfg["model"].get("times")
+ self.logger.info("times: {}".format(times))
+ self.logger.info("mag_stats: {}".format(self.mag_stats))
+ self.logger.info("loc_stats: {}".format(self.loc_stats))
+ self.logger.info("pga_stats: {}".format(self.pga_stats))
+
+ def test(self):
+ """
+ Perform inference for all specified times, process predictions, and save results.
+ This method iterates over the specified times, performs predictions, processes
+ the results, and saves the final statistics.
+ """
+ data_data, evt_key, evt_metadata, meta_data, data_path = load_data(self.cfg)
+ pga_true = data_data["pga"]
+ for time in self.cfg["model"].get("times"):
+ pred = predict_at_time(
+ self.model,
+ time,
+ data_data,
+ data_path,
+ evt_key,
+ evt_metadata,
+ config=self.cfg,
+ pga=self.pga,
+ sampling_rate=meta_data["sampling_rate"],
+ )
+ mag_pred, loc_pred, pga_pred = self._parse_predictions(pred)
+ self._process_predictions(
+ mag_pred, loc_pred, pga_pred, time, evt_metadata, pga_true
+ )
+ self._save_results()
+ print("Inference completed and results saved")
diff --git a/MindEarth/applications/earthquake/G-TEAM/src/models.py b/MindEarth/applications/earthquake/G-TEAM/src/models.py
new file mode 100644
index 0000000000000000000000000000000000000000..d7a650c9d266161043ee21ea2a7abf3ed3955816
--- /dev/null
+++ b/MindEarth/applications/earthquake/G-TEAM/src/models.py
@@ -0,0 +1,408 @@
+# Copyright 2023 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"GTeam model"
+import numpy as np
+
+import mindspore as ms
+import mindspore.nn as nn
+import mindspore.ops as ops
+
+
+class MLP(nn.Cell):
+ """
+ A Multi-Layer Perceptron (MLP) class using MindSpore's nn.Cell.
+ Parameters:
+ input_shape: Tuple representing the shape of the input data.
+ dims: Tuple containing the dimensions of each layer. Default is (100, 50).
+ final_activation: The activation function for the final layer. Default is nn.ReLU.
+ """
+
+ def __init__(self, input_shape, dims=(100, 50), final_activation=nn.ReLU()):
+ super().__init__()
+ layers = []
+ in_dim = input_shape[0]
+
+ for dim in dims[:-1]:
+ layers.append(nn.Dense(in_dim, dim))
+ layers.append(nn.ReLU())
+ in_dim = dim
+ layers.append(nn.Dense(in_dim, dims[-1]))
+
+ if final_activation:
+ layers.append(final_activation)
+ self.model = nn.SequentialCell(*layers)
+
+ def construct(self, x):
+ """
+ Forward pass through the network.
+ Parameters:
+ x: Input data to the MLP.
+ Returns:
+ Output after passing through the MLP.
+ """
+ return self.model(x)
+
+
+class NormalizedScaleEmbedding(nn.Cell):
+ """
+ A neural network module that normalizes input data, extracts features using a series of
+ convolutional and pooling layers, and processes the features through a multi-layer perceptron (MLP).
+ """
+
+ def __init__(self, downsample=5, mlp_dims=(500, 300, 200, 150), eps=1e-8):
+ """
+ Initialize the module with given parameters.
+ Parameters:
+ :downsample: Downsampling factor for the first convolutional layer.
+ :mlp_dims: Dimensions for the MLP layers.
+ :eps: A small value for numerical stability.
+ """
+ super().__init__()
+ self.downsample = downsample
+ self.mlp_dims = mlp_dims
+ self.eps = eps
+
+ self.conv2d_1 = nn.Conv2d(
+ 1,
+ 8,
+ kernel_size=(downsample, 1),
+ stride=(downsample, 1),
+ has_bias=True,
+ pad_mode="pad",
+ )
+ self.conv2d_2 = nn.Conv2d(
+ 8, 32, kernel_size=(16, 3), stride=(1, 1), has_bias=True, pad_mode="pad"
+ )
+
+ self.conv1d_1 = nn.Conv1d(32, 64, kernel_size=16, has_bias=True, pad_mode="pad")
+ self.maxpool_1 = nn.MaxPool1d(kernel_size=2, stride=2)
+ self.conv1d_2 = nn.Conv1d(
+ 64, 128, kernel_size=16, has_bias=True, pad_mode="pad"
+ )
+ self.maxpool_2 = nn.MaxPool1d(kernel_size=2, stride=2)
+ self.conv1d_3 = nn.Conv1d(128, 32, kernel_size=8, has_bias=True, pad_mode="pad")
+ self.maxpool_3 = nn.MaxPool1d(kernel_size=2, stride=2)
+ self.conv1d_4 = nn.Conv1d(32, 32, kernel_size=8, has_bias=True, pad_mode="pad")
+ self.conv1d_5 = nn.Conv1d(32, 16, kernel_size=4, has_bias=True, pad_mode="pad")
+
+ self.flatten = nn.Flatten()
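+ # The MLP input width of 865 assumes 3000-sample, 3-channel waveform windows
+ # (e.g. 30 s at 100 Hz): the conv/pool stack below then yields 16 channels x 54
+ # samples = 864 flattened features, plus the log-amplitude scale appended in construct.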
+ self.mlp = MLP((865,), dims=self.mlp_dims)
+ self.leaky_relu = nn.LeakyReLU(alpha=0.01)
+
+ def construct(self, x):
+ """
+ Forward pass through the network.
+ :param x: Input tensor.
+ :return: Processed output tensor.
+ """
+ original_input = x
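+ # Normalize each trace by its peak absolute amplitude; the log of that peak is
+ # kept separately as a "scale" feature so the absolute amplitude is not lost.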
+ x = (
+ x
+ / (
+ ops.max(
+ ops.max(ops.abs(x), axis=1, keepdims=True)[0], axis=2, keepdims=True
+ )[0]
+ + self.eps
+ )
+ + self.eps
+ )
+ x = ops.unsqueeze(x, dim=1)
+
+ scale = (
+ ops.log(
+ ops.max(ops.max(ops.abs(original_input), axis=1)[0], axis=1)[0]
+ + self.eps
+ )
+ / 100
+ + self.eps
+ )
+ scale = ops.unsqueeze(scale, dim=1)
+
+ x = self.leaky_relu(self.conv2d_1(x))
+ x = self.leaky_relu(self.conv2d_2(x))
+
+ # Drop the singleton width dimension left by the 2D convolutions.
+ x = ops.Squeeze(axis=-1)(x)
+ x = self.leaky_relu(self.conv1d_1(x))
+ x = self.maxpool_1(x)
+ x = self.leaky_relu(self.conv1d_2(x))
+ x = self.maxpool_2(x)
+ x = self.leaky_relu(self.conv1d_3(x))
+ x = self.maxpool_3(x)
+ x = self.leaky_relu(self.conv1d_4(x))
+ x = self.leaky_relu(self.conv1d_5(x))
+
+ x = self.flatten(x)
+ x = ops.cat((x, scale), axis=1)
+ x = self.mlp(x)
+ return x
+
+
+class TransformerEncoder(nn.Cell):
+ """
+ TransformerEncoder class, used to implement the Transformer encoder.
+ Parameters:
+ d_model: Dimension of the input data.
+ nhead: Number of heads in multi-head attention.
+ num_layers: Number of layers in the encoder.
+ batch_first: Whether to consider the first dimension of the input data as the batch dimension.
+ activation: Type of activation function.
+ dim_feedforward: Dimension of the hidden layer in the feedforward network.
+ dropout: Proportion of dropout.
+ Methods:
+ __init__: Initialize the TransformerEncoder object.
+ construct: Construct the TransformerEncoder network.
+ """
+
+ def __init__(
+ self,
+ d_model=500,
+ nhead=10,
+ num_layers=6,
+ batch_first=True,
+ activation="gelu",
+ dim_feedforward=1000,
+ dropout=0.0,
+ ):
+ super().__init__()
+ self.encoder_layer = nn.TransformerEncoderLayer(
+ d_model=d_model,
+ nhead=nhead,
+ batch_first=batch_first,
+ dim_feedforward=dim_feedforward,
+ dropout=dropout,
+ activation=activation,
+ )
+ self.transformer_encoder = nn.TransformerEncoder(
+ self.encoder_layer, num_layers=num_layers
+ )
+
+ def construct(self, x, src_key_padding_mask=None):
+ """Construct the TransformerEncoder network"""
+ return self.transformer_encoder(x, src_key_padding_mask=src_key_padding_mask)
+
+
+class PositionEmbedding(nn.Cell):
+ """
+ PositionEmbedding class, used to implement position embeddings.
+ Parameters:
+ wavelengths: Range of wavelengths.
+ emb_dim: Dimension of the embeddings.
+ Methods:
+ __init__: Initialize the PositionEmbedding object.
+ construct: Construct the PositionEmbedding network.
+ """
+
+ def __init__(self, wavelengths, emb_dim):
+ super().__init__()
+ self.wavelengths = wavelengths
+ self.emb_dim = emb_dim
+
+ min_lat, max_lat = wavelengths[0]
+ min_lon, max_lon = wavelengths[1]
+ min_depth, max_depth = wavelengths[2]
+ lat_dim = emb_dim // 5
+ lon_dim = emb_dim // 5
+ depth_dim = emb_dim // 10
+ self.lat_coeff = (
+ 2
+ * np.pi
+ * 1.0
+ / min_lat
+ * ((min_lat / max_lat) ** (np.arange(lat_dim) / lat_dim))
+ )
+ self.lon_coeff = (
+ 2
+ * np.pi
+ * 1.0
+ / min_lon
+ * ((min_lon / max_lon) ** (np.arange(lon_dim) / lon_dim))
+ )
+ self.depth_coeff = (
+ 2
+ * np.pi
+ * 1.0
+ / min_depth
+ * ((min_depth / max_depth) ** (np.arange(depth_dim) / depth_dim))
+ )
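+ # Interleave sin/cos components of lat, lon, and depth into a single emb_dim vector:
+ # positions with index % 5 == 0/1 hold lat sin/cos, % 5 == 2/3 hold lon sin/cos,
+ # and index % 10 == 4/9 hold depth sin/cos.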
+ lat_sin_mask = np.arange(emb_dim) % 5 == 0
+ lat_cos_mask = np.arange(emb_dim) % 5 == 1
+ lon_sin_mask = np.arange(emb_dim) % 5 == 2
+ lon_cos_mask = np.arange(emb_dim) % 5 == 3
+
+ depth_sin_mask = np.arange(emb_dim) % 10 == 4
+ depth_cos_mask = np.arange(emb_dim) % 10 == 9
+
+ self.mask = np.zeros(emb_dim)
+ self.mask[lat_sin_mask] = np.arange(lat_dim)
+ self.mask[lat_cos_mask] = lat_dim + np.arange(lat_dim)
+ self.mask[lon_sin_mask] = 2 * lat_dim + np.arange(lon_dim)
+ self.mask[lon_cos_mask] = 2 * lat_dim + lon_dim + np.arange(lon_dim)
+ self.mask[depth_sin_mask] = 2 * lat_dim + 2 * lon_dim + np.arange(depth_dim)
+ self.mask[depth_cos_mask] = (
+ 2 * lat_dim + 2 * lon_dim + depth_dim + np.arange(depth_dim)
+ )
+ self.mask = ms.tensor(self.mask.astype("int32"))
+
+ def construct(self, x):
+ """position embedding"""
+ lat_base = x[:, :, 0:1] * ms.tensor(self.lat_coeff, dtype=ms.float32)
+ lon_base = x[:, :, 1:2] * ms.tensor(self.lon_coeff, dtype=ms.float32)
+ depth_base = x[:, :, 2:3] * ms.tensor(self.depth_coeff, dtype=ms.float32)
+
+ output = ops.cat(
+ [
+ ops.sin(lat_base),
+ ops.cos(lat_base),
+ ops.sin(lon_base),
+ ops.cos(lon_base),
+ ops.sin(depth_base),
+ ops.cos(depth_base),
+ ],
+ axis=-1,
+ )
+ output = ops.index_select(output, axis=-1, index=self.mask)
+
+ return output
+
+
+class AddEventToken(nn.Cell):
+ """
+ AddEventToken class, used to implement adding event tokens.
+
+ Parameters:
+ emb_dim: Dimension of the embeddings.
+ init_range: Initialization range.
+
+ Methods:
+ __init__: Initialize the AddEventToken object.
+ construct: Construct the AddEventToken network.
+ """
+
+ def __init__(self, emb_dim, init_range):
+ super().__init__()
+ self.emb_dim = emb_dim
+ init_value = np.random.uniform(-init_range, init_range, (1, 1, emb_dim)).astype(
+ np.float32
+ )
+ self.event_token = ms.Parameter(ms.Tensor(init_value), name="event_token")
+
+ def construct(self, x):
+ """add eventtoken"""
+ event_token = self.event_token
+ pad = ops.ones_like(x[:, :1, :]) * event_token
+ x = ops.cat([pad, x], axis=1)
+
+ return x
+
+
+def _init_pad_mask(waveforms, pga_targets):
+ """
+ Build the key-padding mask covering the event token, station tokens, and PGA target tokens.
+ """
+ station_pad_mask = abs(waveforms) < 1e-8
+ station_pad_mask = ops.all(station_pad_mask, axis=2)
+ station_pad_mask = ops.all(station_pad_mask, axis=2)
+
+ event_token_mask = ops.zeros((station_pad_mask.shape[0], 1), dtype=ms.dtype.bool_)
+ pad_mask = ops.cat([event_token_mask, station_pad_mask], axis=1)
+
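+ # PGA target tokens are always flagged as padding, so other tokens never attend to
+ # them; as queries they can still read from the unmasked event and station tokens.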
+ target_pad_mask = ms.numpy.ones_like(pga_targets, dtype=ms.dtype.bool_)
+ target_pad_mask = ops.all(target_pad_mask, 2)
+
+ pad_mask = ops.cat((pad_mask, target_pad_mask), axis=1)
+
+ return pad_mask
+
+
+class WaveformFullmodel(nn.Cell):
+ """
+ Full G-TEAM waveform model: encodes per-station waveforms, combines them with position
+ embeddings in a Transformer encoder, and regresses magnitude, location, and PGA.
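+
+ A minimal inference sketch (hypothetical shapes: 3000-sample, 3-channel windows, which
+ match the hard-coded 865-wide MLP input; real hyperparameters come from
+ config/GTEAM.yaml and real inputs from src.data):
+
+ model = WaveformFullmodel(
+ output_mlp_dims=(150, 100, 50, 30, 1),
+ output_location_dims=(150, 100, 50, 30, 3),
+ n_pga_targets=15,
+ )
+ waveforms = ops.zeros((1, 5, 3000, 3), ms.float32)  # (batch, stations, samples, channels)
+ metadata = ops.zeros((1, 5, 3), ms.float32)  # station lat/lon/depth
+ pga_targets = ops.zeros((1, 15, 3), ms.float32)  # target lat/lon/depth
+ mag, loc, pga = model(waveforms, metadata, pga_targets)
+ # mag: (1, 1), loc: (1, 3), pga: (1, 15, 1)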
+ """
+
+ def __init__(
+ self,
+ waveform_model_dims=(500, 500, 500),
+ output_mlp_dims=(150, 100, 50, 30, 10),
+ output_location_dims=(150, 100, 50, 50, 50),
+ wavelength=((0.01, 10), (0.01, 10), (0.01, 10)),
+ n_heads=10,
+ hidden_dim=1000,
+ transformer_layers=6,
+ hidden_dropout=0.0,
+ n_pga_targets=0,
+ downsample=5,
+ ):
+ super().__init__()
+ self.waveform_model = NormalizedScaleEmbedding(
+ downsample=downsample, mlp_dims=waveform_model_dims
+ )
+ self.transformer = TransformerEncoder(
+ d_model=waveform_model_dims[-1],
+ nhead=n_heads,
+ num_layers=transformer_layers,
+ dim_feedforward=hidden_dim,
+ dropout=hidden_dropout,
+ )
+
+ self.mlp_mag = MLP((waveform_model_dims[-1],), output_mlp_dims)
+ self.mlp_loc = MLP(
+ (waveform_model_dims[-1],), output_location_dims, final_activation=None
+ )
+ self.mlp_pga = MLP(
+ (waveform_model_dims[-1],), output_mlp_dims, final_activation=None
+ )
+
+ self.position_embedding = PositionEmbedding(
+ wavelengths=wavelength, emb_dim=waveform_model_dims[-1]
+ )
+ self.addeventtoken = AddEventToken(emb_dim=500, init_range=0.02)
+ self.n_pga_targets = n_pga_targets
+
+ def cal_waveforms_emb_normalized(self, waveforms_emb):
+ """Normalize the waveform embeddings"""
+ mean_vals = waveforms_emb.mean(axis=2, keep_dims=True)
+ std_vals = waveforms_emb.std(axis=2, keepdims=True)
+ waveforms_emb_normalized = (waveforms_emb - mean_vals) / (std_vals + 1e-8)
+ return waveforms_emb_normalized
+
+ def construct(self, waveforms, metadata, pga_targets):
+ """
+ Construct method to process the input waveforms, metadata, and PGA targets.
+ """
+ batch_size, num_stations, seq_length, num_channels = waveforms.shape
+ waveforms_reshape = waveforms.reshape(-1, seq_length, num_channels)
+
+ waveforms_emb = self.waveform_model(waveforms_reshape)
+ waveforms_emb = waveforms_emb.reshape(batch_size, num_stations, -1)
+ waveforms_emb_normalized = self.cal_waveforms_emb_normalized(waveforms_emb)
+ coords_emb = self.position_embedding(metadata)
+ pga_target_emb = self.position_embedding(pga_targets)
+ pad_mask = _init_pad_mask(waveforms, pga_targets)
+
+ emb_pos = waveforms_emb_normalized + coords_emb
+ emb_pos = self.addeventtoken(emb_pos)
+ emb_pos_pga = ops.cat((emb_pos, pga_target_emb), axis=1)
+ emb_pos_pga_trans = self.transformer(emb_pos_pga, pad_mask)
+ emb_pga = emb_pos_pga_trans[:, -self.n_pga_targets :, :]
+ emb_mag_loc = emb_pos_pga_trans[:, 0, :]
+
+ mag = self.mlp_mag(emb_mag_loc)
+ loc = self.mlp_loc(emb_mag_loc)
+
+ pga_all = self.mlp_pga(emb_pga)
+ outputs = [mag, loc, pga_all]
+
+ return outputs
diff --git a/MindEarth/applications/earthquake/G-TEAM/src/utils.py b/MindEarth/applications/earthquake/G-TEAM/src/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..f1a4ae25019e61ffeeb4621f2d2ccabd6d36ecfe
--- /dev/null
+++ b/MindEarth/applications/earthquake/G-TEAM/src/utils.py
@@ -0,0 +1,180 @@
+# Copyright 2023 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"GTeam util"
+import os
+import copy
+import numpy as np
+import sklearn.metrics as metrics
+from geopy.distance import geodesic
+
+import mindspore as ms
+import mindspore.ops as ops
+
+from mindearth.utils import create_logger
+
+from src import data
+from src.data import generator_from_config, D2KM
+from src.models import WaveformFullmodel
+
+
+def predict_at_time(
+ model,
+ time,
+ data_data,
+ data_path,
+ event_key,
+ event_metadata,
+ config,
+ sampling_rate=100,
+ pga=False,
+):
+ """Predict at a specific time point"""
+ generator = generator_from_config(
+ config,
+ data_data,
+ data_path,
+ event_key,
+ event_metadata,
+ time,
+ sampling_rate=sampling_rate,
+ pga=pga,
+ )
+
+ pred_list_mag = []
+ pred_list_loc = []
+ pred_list_pga = []
+ for i in range(len(generator)):
+ x, _ = generator[i]
+
+ pred = model(x[0], x[1], x[2])
+ pred_list_mag.append(pred[0])
+ pred_list_loc.append(pred[1])
+ pred_list_pga.append(pred[2])
+
+ pre_mag = ops.cat(pred_list_mag, axis=0)
+ pre_loc = ops.cat(pred_list_loc, axis=0)
+ pre_pga = ops.cat(pred_list_pga, axis=0)
+ predictions = [pre_mag, pre_loc, pre_pga]
+
+ mag_pred_filter = []
+ loc_pred_filter = []
+ pga_pred_filter = []
+
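+ # Split the batched predictions back into per-event groups using reverse_index,
+ # keeping only as many predictions as each event has PGA targets.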
+ for i, (start, end) in enumerate(
+ zip(generator.dataset.reverse_index[:-1], generator.dataset.reverse_index[1:])
+ ):
+ sample_mag_pred = predictions[0][start:end].reshape(
+ (-1,) + predictions[0].shape[-1:]
+ )
+ sample_mag_pred = sample_mag_pred[: len(generator.dataset.pga[i])]
+ mag_pred_filter += [sample_mag_pred]
+
+ sample_loc_pred = predictions[1][start:end].reshape(
+ (-1,) + predictions[1].shape[-1:]
+ )
+ sample_loc_pred = sample_loc_pred[: len(generator.dataset.pga[i])]
+ loc_pred_filter += [sample_loc_pred]
+
+ sample_pga_pred = predictions[2][start:end].reshape(
+ (-1,) + predictions[2].shape[-1:]
+ )
+ sample_pga_pred = sample_pga_pred[: len(generator.dataset.pga[i])]
+ pga_pred_filter += [sample_pga_pred]
+
+ preds = [mag_pred_filter, loc_pred_filter, pga_pred_filter]
+
+ return preds
+
+
+def calc_mag_stats(mag_pred, event_metadata, key):
+ """Calculate statistical information for magnitude predictions"""
+ mean_mag = mag_pred
+ true_mag = event_metadata[key].values
+ # R^2
+ r2 = metrics.r2_score(true_mag, mean_mag)
+ # RMSE
+ rmse = np.sqrt(metrics.mean_squared_error(true_mag, mean_mag))
+ # MAE
+ mae = metrics.mean_absolute_error(true_mag, mean_mag)
+ return r2, rmse, mae
+
+
+def calc_pga_stats(pga_pred, pga_true, suffix=""):
+ """Calculate statistical information for PGA predictions"""
+ if suffix:
+ suffix += "_"
+ valid_mask = np.isfinite(pga_true) & np.isfinite(pga_pred)
+ pga_true_clean = pga_true[valid_mask]
+ pga_pred_clean = pga_pred[valid_mask]
+ r2 = metrics.r2_score(pga_true_clean, pga_pred_clean)
+ rmse = np.sqrt(metrics.mean_squared_error(pga_true_clean, pga_pred_clean))
+ mae = metrics.mean_absolute_error(pga_true_clean, pga_pred_clean)
+
+ return [r2, rmse, mae]
+
+
+def calc_loc_stats(loc_pred, event_metadata, pos_offset):
+ """Calculate statistical information for location predictions"""
+ coord_keys = data.detect_location_keys(event_metadata.columns)
+ true_coords = event_metadata[coord_keys].values
+ mean_coords = loc_pred
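+ # Undo normalization: predictions are scaled by 100, lat/lon converted from km back
+ # to degrees via D2KM, and shifted by the position offset (presumably mirroring the
+ # normalization applied in src.data).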
+ mean_coords *= 100
+ mean_coords[:, :2] /= D2KM
+ mean_coords[:, 0] += pos_offset[0]
+ mean_coords[:, 1] += pos_offset[1]
+
+ dist_epi = np.zeros(len(mean_coords))
+ dist_hypo = np.zeros(len(mean_coords))
+ real_dep = np.zeros(len(mean_coords))
+ pred_dep = np.zeros(len(mean_coords))
+ for i, (pred_coord, true_coord) in enumerate(zip(mean_coords, true_coords)):
+ dist_epi[i] = geodesic(pred_coord[:2], true_coord[:2]).km
+ dist_hypo[i] = np.sqrt(dist_epi[i] ** 2 + (pred_coord[2] - true_coord[2]) ** 2)
+ real_dep[i] = true_coord[2]
+ pred_dep[i] = pred_coord[2]
+
+ rmse_epi = np.sqrt(np.mean(dist_epi**2))
+ mae_epi = np.mean(np.abs(dist_epi))
+
+ rmse_hypo = np.sqrt(np.mean(dist_hypo**2))
+ mae_hypo = np.mean(dist_hypo)
+
+ return rmse_hypo, mae_hypo, rmse_epi, mae_epi
+
+
+def init_model(arg):
+ """set model"""
+ tmpcfg = copy.deepcopy(arg["model"])
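+ # Drop config keys that are not WaveformFullmodel constructor arguments before unpacking.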
+ tmpcfg.pop("no_event_token")
+ tmpcfg.pop("run_with_less_data")
+ tmpcfg.pop("pga")
+ tmpcfg.pop("mode")
+ tmpcfg.pop("times")
+ model = WaveformFullmodel(**tmpcfg)
+ param_dict = ms.load_checkpoint(arg["summary"].get("ckpt_path"))
+ # Load parameters into the network
+ ms.load_param_into_net(model, param_dict)
+ model.set_train(False)
+ return model
+
+
+def get_logger(config):
+ """Get logger for saving log"""
+ summary_params = config.get("summary")
+ logger = create_logger(
+ path=os.path.join(summary_params.get("summary_dir"), "results.log")
+ )
+ for key in config:
+ logger.info(config[key])
+ return logger
diff --git a/MindEarth/applications/earthquake/G-TEAM/src/visual.py b/MindEarth/applications/earthquake/G-TEAM/src/visual.py
new file mode 100644
index 0000000000000000000000000000000000000000..ef3dc119504fb9d8ef48bbd65dcebf626d60d044
--- /dev/null
+++ b/MindEarth/applications/earthquake/G-TEAM/src/visual.py
@@ -0,0 +1,79 @@
+# Copyright 2023 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"visualization"
+import os
+import matplotlib.pyplot as plt
+import numpy as np
+import sklearn.metrics as metrics
+
+
+def generate_true_pred_plot(pred_values, true_values, time, path, suffix=""):
+ """
+ Generate a plot comparing true values and predicted values, and calculate
+ evaluation metrics including MAE, RMSE, R^2, and the standard deviation of residuals.
+ Parameters:
+ pred_values: List of predicted values
+ true_values: List of true values
+ time: Time, used for naming the image
+ path: Path to save the image
+ suffix: Suffix for image naming, default is an empty string
+ """
+ if suffix:
+ suffix += "_"
+ fig = plt.figure(figsize=(9, 9))
+ plt.plot(true_values, pred_values, "ok", alpha=0.2)
+ pred_value = np.asarray(pred_values)
+ r2 = metrics.r2_score(true_values, pred_value)
+ rmse = np.sqrt(metrics.mean_squared_error(true_values, pred_value))
+ mae = metrics.mean_absolute_error(true_values, pred_value)
+
+ plt.text(
+ 0.6,
+ 6,
+ f"MAE: {mae:.2f}\nRMSE: {rmse:.2f}\n$R^{2}$: {r2:.2f}",
+ fontsize=30,
+ verticalalignment="top",
+ horizontalalignment="left",
+ )
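+ # 1:1 reference line; the fixed 0-7 axis range assumes magnitude-scale values.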
+ plt.plot(np.arange(0, 8), np.arange(0, 8), "-r")
+ plt.xlim(0, 7)
+ plt.ylim(0, 7)
+ ax = plt.gca()
+ ax.set_xlabel("True values", fontsize=20)
+ ax.set_ylabel("Pred values", fontsize=20)
+ ax.set_title(str(time) + " s", fontsize=20)
+ fig.savefig(os.path.join(path, f"truepred_{suffix}{time}.png"), bbox_inches="tight")
+ plt.close()
+
+ residual = true_values - pred_value
+ fig = plt.figure(figsize=(9, 9))
+ axs = fig.subplots(1, 1)
+ axs.hist(residual)
+ axs.set_xlabel("residual", fontsize=25)
+ axs.set_ylabel("Event Number", fontsize=25)
+ x_lim = axs.get_xlim()
+ y_lim = axs.get_ylim()
+ plt.text(
+ x_lim[1] * 0.95,
+ y_lim[1] * 0.95,
+ f"MAE: {mae:.2f}\nRMSE: {rmse:.2f}\n$R^{{2}}$: {r2:.2f}\nSTD: {np.std(residual):.2f}",
+ fontsize=30,
+ verticalalignment="top",
+ horizontalalignment="right",
+ )
+
+ fig.savefig(os.path.join(path, f"Residual_{suffix}{time}.png"), bbox_inches="tight")
+ plt.close()