diff --git a/MindSPONGE/applications/research/Geneformer/README.md b/MindSPONGE/applications/research/Geneformer/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..30e664cd63d767d29c7b4221b54bf37ecab6cd9a
--- /dev/null
+++ b/MindSPONGE/applications/research/Geneformer/README.md
@@ -0,0 +1,134 @@
+
+## Geneformer Classification Downstream Task
+
+## Overview
+
+Mapping gene networks requires large amounts of transcriptomic data, yet such data are often very limited for rare diseases or tissues that are difficult to access clinically. This data scarcity hinders the discovery of network-correcting therapeutics, raising the central question of how to effectively map gene network architecture and discover key regulators and candidate therapeutic targets when data are limited. The difficulty stems from the complexity of gene networks compounded by data scarcity, which is most acute in rare diseases and clinically inaccessible tissues.
+
+Drawing on the success of transfer learning, the researchers developed Geneformer, a context-aware, attention-based deep learning model. Pretrained on large-scale transcriptomic data, it acquired a fundamental understanding of network dynamics. Fine-tuned with a limited number of task-specific training examples, Geneformer adapts to a diverse range of downstream tasks relevant to chromatin and network dynamics and improves predictive accuracy. Beyond addressing today's data-limited settings, Geneformer opens new directions for gene network research and drug discovery, and holds promise for advancing the field and for the treatment of rare and complex diseases.
+
+### Method
+
+Geneformer is a foundational Transformer model pretrained on large-scale single-cell transcriptomes from human tissues. The model was originally pretrained in June 2021 on Genecorpus-30M, a corpus of approximately 30 million single-cell transcriptomes. To facilitate interpretation, cells with high mutational burdens (e.g., malignant cells and immortalized cell lines) were excluded, as they could lead to substantial network rewiring without companion genome sequencing. Subsequently, in April 2024, Geneformer was pretrained on approximately 95 million non-cancer transcriptomes, followed by continual learning on approximately 14 million cancer transcriptomes, yielding a model tuned for the cancer domain.
+
+Each single-cell transcriptome is presented to the model as a rank value encoding, in which genes are ranked by their expression in that cell normalized by their expression across the entire Genecorpus-30M corpus. The rank value encoding provides a nonparametric representation of the cell's transcriptome and exploits the many observations of each gene's expression in the pretraining corpus to prioritize genes that distinguish cell state. Concretely, this approach deprioritizes ubiquitously highly expressed housekeeping genes (lowering their rank), while genes such as transcription factors, which may be lowly expressed yet highly distinguish cell state, move to higher ranks within the encoding. Furthermore, this rank-based approach may be more robust to technical artifacts that systematically bias absolute transcript counts, since the relative ranking of genes within each cell remains largely stable.
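+
+A minimal NumPy sketch of this rank value encoding follows; it is illustrative only (not the repository's tokenizer), and the gene names and `median_expression` values are made-up stand-ins for the medians shipped in `gene_median_dictionary_gc95M.pkl`:
+
+```python
+import numpy as np
+
+# Hypothetical per-cell raw counts and corpus-wide median expression per gene.
+genes = np.array(["GAPDH", "TBX5", "NKX2-5", "ACTB"])
+cell_counts = np.array([900.0, 12.0, 8.0, 1100.0])
+median_expression = np.array([850.0, 3.0, 2.5, 1000.0])  # stand-in medians
+
+# Normalize each gene by its corpus-wide median, then rank in descending order:
+# housekeeping genes (GAPDH, ACTB) drop in rank, while lowly expressed but
+# cell-state-defining transcription factors (TBX5, NKX2-5) rise.
+normalized = cell_counts / median_expression
+rank_value_encoding = genes[np.argsort(-normalized)]
+print(rank_value_encoding)  # ['TBX5' 'NKX2-5' 'ACTB' 'GAPDH']
+```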
+
+The rank value encoding of each single-cell transcriptome then passes through N layers of Transformer encoder units, where N varies with model size. Pretraining uses a masked learning objective: 15% of the genes in each transcriptome are masked, and the model is trained to predict which gene belongs at each masked position using the context of the remaining unmasked genes. A key advantage of this approach is that it is entirely self-supervised and can be carried out on completely unlabeled data, which allows the inclusion of vast amounts of training data without being restricted to samples with accompanying labels.
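+
+As a hedged illustration of this 15% masking objective (not the actual pretraining code; the `MASK_ID` value and token ids are invented):
+
+```python
+import numpy as np
+
+MASK_ID = 1  # assumed mask-token id; the real id comes from the token dictionary
+rng = np.random.default_rng(0)
+
+def mask_for_mlm(token_ids, mask_prob=0.15):
+    """Mask ~15% of gene tokens. Labels keep the original id at masked
+    positions and -100 elsewhere so the loss ignores unmasked tokens."""
+    token_ids = np.asarray(token_ids)
+    mask = rng.random(token_ids.shape) < mask_prob
+    labels = np.where(mask, token_ids, -100)
+    inputs = np.where(mask, MASK_ID, token_ids)
+    return inputs, labels
+
+inputs, labels = mask_for_mlm(rng.integers(2, 25426, size=2048))
+```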
+
+During pretraining, Geneformer gained a fundamental understanding of network dynamics, encoding the network hierarchy within the model's attention weights in a completely self-supervised manner. With zero-shot learning and fine-tuning on limited task-specific data, Geneformer consistently improved predictive accuracy across a diverse panel of downstream tasks relevant to chromatin and network dynamics. Using in silico perturbation with zero-shot learning, a novel transcription factor was identified in cardiomyocytes and experimentally validated as critical for their ability to generate contractile force. Using in silico treatment analysis with limited patient data, candidate therapeutic targets for cardiomyopathy were identified and experimentally validated to significantly improve the contractile capacity of cardiomyocytes in an induced pluripotent stem cell (iPSC) model of the disease. Overall, Geneformer is a foundational deep learning model pretrained on large-scale human single-cell transcriptomes that gained a fundamental understanding of network dynamics, which can now be applied to a broad range of downstream tasks to accelerate the discovery of key network regulators and candidate therapeutic targets.
+
+As shown in the figure below, transfer learning from the initial self-supervised large-scale pretraining works by copying the pretrained weights into a model for each fine-tuning task, adding a fine-tuning layer, and fine-tuning with limited data for each specific downstream task. Through a single initial self-supervised pretraining on a generalizable learning objective, the model gains fundamental knowledge of the learning domain and then transfers that knowledge to a multitude of downstream applications whose objectives differ from the pretraining objective.
+
+![Geneformer architecture](images/model.png)
+
+The pretrained Geneformer architecture: each single-cell transcriptome is encoded as a rank value encoding and processed through a six-layer Transformer encoder with the following parameters: input size 2048 (fully representing 93% of the rank value encodings in Genecorpus-30M), 256 embedding dimensions, four attention heads per layer, and a feed-forward size of 512. Geneformer uses full dense self-attention over the 2048 input size. Extractable outputs include contextual gene and cell embeddings, contextual attention weights, and contextual predictions.
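+
+In this repository, the "add a fine-tuning layer and freeze the lower encoder blocks" step is implemented in `src/classifier.py`; a condensed sketch of that freezing logic (assuming a loaded MindSpore `model` with the parameter names produced by `convert_weight.py`):
+
+```python
+def freeze_encoder_layers(model, freeze_layers):
+    """Disable gradients for the first `freeze_layers` transformer blocks."""
+    prefix = "bert.bert_encoder.encoder.blocks."
+    for param in model.get_parameters(expand=True):
+        if param.name.startswith(prefix):
+            block_index = int(param.name[len(prefix):].split(".")[0])
+            if block_index < freeze_layers:
+                param.requires_grad = False
+```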
+
+### Data Preparation
+
+Part 1: The dataset used is Genecorpus-30M, available at https://huggingface.co/datasets/ctheodoris/Genecorpus-30M.
+
+Part 2: Download the `*.pkl` dictionary files from https://huggingface.co/ctheodoris/Geneformer/tree/main/geneformer and copy them into the `src` directory.
+
+Part 3: Obtain the original fine-tuned model files from https://huggingface.co/ctheodoris/Geneformer/tree/main/fine_tuned_models/gf-6L-30M-i2048_CellClassifier_cardiomyopathies_220224.
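+
+One way to fetch these files programmatically is `huggingface_hub` (a sketch; the file names below are taken from `src/__init__.py` and the repositories above, so verify them against the current listings):
+
+```python
+from huggingface_hub import hf_hub_download
+
+# Part 2: the .pkl dictionaries (they land under ./src/geneformer/; move them
+# up into src/ as required above).
+for name in ["gene_median_dictionary_gc95M.pkl", "token_dictionary_gc95M.pkl",
+             "gene_name_id_dict_gc95M.pkl", "ensembl_mapping_dict_gc95M.pkl"]:
+    hf_hub_download(repo_id="ctheodoris/Geneformer",
+                    filename=f"geneformer/{name}", local_dir="./src")
+
+# Part 3: the original fine-tuned weights.
+hf_hub_download(
+    repo_id="ctheodoris/Geneformer",
+    filename="fine_tuned_models/gf-6L-30M-i2048_CellClassifier_cardiomyopathies_220224/pytorch_model.bin",
+    local_dir=".")
+```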
+
+### Directory Structure
+
+```shell
+
+ geneformer                            # model name
+  ├── README.md                        # model documentation
+  ├── scripts                          # script directory
+  │   ├── run.sh                       # single-device launch script
+  │   ├── run_8p.sh                    # 8-device distributed launch script
+  │   ├── main.py                      # Python entry point for training/eval
+  │   └── convert_weight.py            # weight-conversion script
+  ├── src                              # model source directory
+  │   ├── __init__.py                  # package init; paths of the .pkl dictionaries
+  │   ├── classifier.py                # training and validation logic
+  │   ├── classifier_utils.py          # data labeling and splitting utilities
+  │   ├── preparedata.py               # dataset generators and padding
+  │   └── perturber_utils.py           # dataset filtering and model loading
+  └── config                           # configuration directory
+      ├── geneformer_config.yaml       # data and output path configuration
+      └── run_geneformer_args.yaml     # model and training configuration
+
+```
+
+### Parameter Configuration
+
+```yaml
+
+model:                                 # model parameters
+  model_config:                        # model configuration
+    type: BertConfig                   # config class
+    use_one_hot_embeddings: False      # whether to use one-hot embeddings
+    num_labels: 2                      # number of label classes
+    dropout_prob: 0.02                 # dropout probability
+    batch_size: 16                     # batch size
+    seq_length: 2048                   # maximum input sequence length
+    vocab_size: 25426                  # vocabulary size
+    hidden_size: 256                   # hidden dimension
+    num_hidden_layers: 6               # number of hidden layers
+    num_attention_heads: 4             # number of attention heads
+    hidden_act: "relu"                 # hidden-layer activation
+    checkpoint_name_or_path: ""        # path to model weights to load
+  arch:
+    type: BertForPreTraining           # task type
+lr_schedule:                           # learning-rate parameters
+  type: LinearWithWarmUpLR             # learning-rate schedule
+  learning_rate: 0.00005               # initial learning rate (5e-5)
+  lr_end: 0.0000000001                 # final learning rate (1e-10)
+layer_decay: 0.65                      # layer-wise learning-rate decay factor
+optimizer:                             # optimizer parameters
+  type: adamw                          # optimizer type
+  weight_decay: 0.001                  # weight-decay strength
+callbacks:
+  - type: MFLossMonitor                # loss-monitoring callback
+  - type: CheckpointMonitor            # checkpoint-saving callback
+    prefix: "mindformers"              # checkpoint filename prefix
+    keep_checkpoint_max: 100           # maximum checkpoints to keep
+    save_checkpoint_steps: 500         # steps between checkpoint saves
+
+```
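+
+These YAML files are read through the MindFormers config loader, as done in `src/classifier.py`; a quick check that the configuration parses as documented:
+
+```python
+from mindformers.tools.register import MindFormerConfig
+
+config = MindFormerConfig("config/run_geneformer_args.yaml")
+print(config.model.model_config.num_hidden_layers)  # 6
+print(config.runner_config.batch_size)              # 12
+```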
+
+### Quick Start
+
+Step 1: Use `scripts/convert_weight.py` to convert the original training file into a MindSpore-framework `.ckpt` file (set `--layers` to the actual number of model layers; `--torch_path` and `--mindspore_path` point to the original training file and the converted MindSpore file, respectively).
+
+```shell
+python3 scripts/convert_weight.py --layers 6 --torch_path pytorch_model.bin --mindspore_path ./out_model/geneformer_mindspore.ckpt
+```
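+
+To sanity-check the conversion, the resulting checkpoint can be inspected with MindSpore (the parameter name below follows the mapping table in `convert_weight.py`):
+
+```python
+import mindspore as ms
+
+params = ms.load_checkpoint("./out_model/geneformer_mindspore.ckpt")
+print(len(params), "parameters converted")
+# Spot-check one converted parameter and its shape (vocab_size x hidden_size).
+print(params["bert.word_embedding.embedding_table"].shape)  # (25426, 256)
+```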
+
+Step 2: Set the `model_output` directory in `config/geneformer_config.yaml` to the output directory from Step 1, then run the launch script to complete training and validation.
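+
+`scripts/main.py` expects the following keys in `config/geneformer_config.yaml`; a minimal sketch of that loading step (the `model_output` directory must contain `geneformer_mindspore.ckpt` from Step 1):
+
+```python
+import yaml
+
+with open("config/geneformer_config.yaml", "r") as f:
+    cfg = yaml.safe_load(f)
+
+for key in ("gene_class_dict_path", "dataset_path", "output_prefix",
+            "output_dir", "data_output", "model_output"):
+    print(key, "=", cfg[key])
+```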
+
+```shell
+cd scripts && bash run_8p.sh
+```
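+
+For a single-device run, use `scripts/run.sh` instead; it accepts `--bert_config_path` and `--dataset_path` to override the default configuration files.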
+
+### Results
+
+Saving the loss values and plotting them yields the following loss curve:
+
+![loss curve](images/loss.png)
+
+### Performance Metrics
+
+| Parameter | Ascend |
+| :-------: | :----: |
+| Hardware | Ascend AI processor |
+| Framework version | MindSpore 2.3.1 |
+| Dataset | Genecorpus-30M |
+| Model variant | 6L-30M-i2048 |
+| Training parameters | batch_size=12, steps_per_epoch=835, epochs=1 |
+| Evaluation parameters | batch_size=16 |
+| Optimizer | AdamW |
+| Train steps/s | 12.34 |
+| Train runtimes | 9.60 |
+| Eval accuracy | 0.70 |
+| Eval F1 | 0.80 |
+
+### References
+
+[1] C V Theodoris#, L Xiao, A Chopra, M D Chaffin, Z R Al Sayed, M C Hill, H Mantineo, E Brydon, Z Zeng, X S Liu, P T Ellinor#. Transfer learning enables predictions in network biology. Nature, 31 May 2023. (#co-corresponding authors)
+
+[2] H Chen*, M S Venkatesh*, J Gomez Ortega, S V Mahesh, T Nandi, R Madduri, K Pelka†, C V Theodoris†#. Quantized multi-task learning for context-specific representations of gene network dynamics. bioRxiv, 19 Aug 2024. (*co-first authors, †co-senior authors, #corresponding author)
\ No newline at end of file
diff --git a/MindSPONGE/applications/research/Geneformer/config/geneformer_config.yaml b/MindSPONGE/applications/research/Geneformer/config/geneformer_config.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..c2a3ff9fc547593c643c9407e82d1afc2551b74b
--- /dev/null
+++ b/MindSPONGE/applications/research/Geneformer/config/geneformer_config.yaml
@@ -0,0 +1,7 @@
+# download Genecorpus-30M datasets from https://huggingface.co/datasets/ctheodoris/Genecorpus-30M/blob/main/example_input_files/gene_classification/dosage_sensitive_tfs/dosage_sensitivity_TFs.pickle and example_input_files/gene_classification/dosage_sensitive_tfs/gc-30M_sample50k.dataset
+gene_class_dict_path: "../dosage_sensitivity_TFs.pickle"
+dataset_path: "../gc-30M_sample50k.dataset"
+output_prefix: "tf_dosage_sens_test"
+output_dir: "../output_dir"
+data_output: "../output_data"
+model_output: "../output_model"
diff --git a/MindSPONGE/applications/research/Geneformer/config/run_geneformer_args.yaml b/MindSPONGE/applications/research/Geneformer/config/run_geneformer_args.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..f96e77f3c162721c1840222d4e2ba661e07bd819
--- /dev/null
+++ b/MindSPONGE/applications/research/Geneformer/config/run_geneformer_args.yaml
@@ -0,0 +1,138 @@
+model:
+ model_config:
+ type: BertConfig
+ use_one_hot_embeddings: False
+ num_labels: 2
+ dropout_prob: 0.02
+ batch_size: 16
+ seq_length: 2048 #length of input sentence
+ vocab_size: 25426 #size of vocab
+ hidden_size: 256 #size of text feature
+ num_hidden_layers: 6 #model depth
+ num_attention_heads: 4 #number of attention heads
+    intermediate_size: 512 # feed-forward size (2 * hidden_size here)
+ hidden_act: "relu" #activation
+ post_layernorm_residual: True #select postlayernorm or prelayernorm
+ hidden_dropout_prob: 0.02
+ attention_probs_dropout_prob: 0.02
+ max_position_embeddings: 2048
+ type_vocab_size: 2
+ initializer_range: 0.02
+ use_relative_positions: False
+ use_past: False
+ use_moe: False
+ compute_dtype: "float32"
+ checkpoint_name_or_path: ""
+ arch:
+ type: BertForPreTraining
+
+lr_schedule:
+ type: LinearWithWarmUpLR
+ learning_rate: 0.00005 # 5e-5
+  lr_end: 0.0000000001 # 1e-10
+ warmup_steps: 0
+ total_steps: -1 # -1 means it will load the total steps of the dataset
+layer_scale: False
+layer_decay: 0.65
+
+optimizer:
+ type: adamw
+ weight_decay: 0.001
+ eps: 0.00000001 # 1e-8
+lr_scale: False
+lr_scale_factor: 256
+
+callbacks:
+ - type: MFLossMonitor
+ - type: CheckpointMonitor
+ prefix: "mindformers"
+ keep_checkpoint_max: 100
+ save_checkpoint_steps: 500
+ integrated_save: True
+ async_save: False
+
+runner_config:
+ epochs: 1
+ batch_size: 12
+ sink_mode: False
+ sink_size: 2
+runner_wrapper:
+ type: TrainOneStepCell
+
+# parallel
+use_parallel: False
+parallel:
+ parallel_mode: 0 # 0-standalone, 1-semi, 2-auto, 3-hybrid
+ gradients_mean: True
+ enable_alltoall: False
+ full_batch: False
+ search_mode: "sharding_propagation"
+ enable_parallel_optimizer: False
+ strategy_ckpt_save_file: "./ckpt_strategy.ckpt"
+parallel_config:
+ data_parallel: 1
+ model_parallel: 1
+ expert_parallel: 1
+ pipeline_stage: 1
+ micro_batch_num: 1
+ gradient_aggregation_group: 4
+micro_batch_interleave_num: 1
+
+# profile
+profile: False
+profile_start_step: 1
+profile_stop_step: 10
+init_start_profile: False
+profile_communication: False
+profile_memory: True
+
+# Trainer
+trainer:
+ type: TokenClassificationTrainer
+ model_name: txtcls_bert_base_uncased
+do_eval: False
+
+# train dataset
+train_dataset: &train_dataset
+ input_columns: ["input_ids", "input_mask", "segment_ids", "label_ids"]
+ num_parallel_workers: 8
+ python_multiprocessing: False
+ drop_remainder: True
+ batch_size: 16
+ repeat: 1
+ numa_enable: False
+ prefetch_size: 1
+ seed: 42
+train_dataset_task:
+ type: TextClassificationDataset
+ dataset_config: *train_dataset
+
+# eval dataset
+eval_dataset: &eval_dataset
+ input_columns: ["input_ids", "input_mask", "segment_ids", "label_ids"]
+ num_parallel_workers: 8
+ python_multiprocessing: False
+ drop_remainder: True
+ batch_size: 64
+ repeat: 1
+ numa_enable: False
+ prefetch_size: 1
+ seed: 42
+eval_dataset_task:
+ type: TextClassificationDataset
+ dataset_config: *eval_dataset
+
+# processor
+processor:
+ return_tensors: ms
+ tokenizer:
+ cls_token: '[CLS]'
+ do_basic_tokenize: True
+ do_lower_case: True
+ mask_token: '[MASK]'
+ pad_token: '[PAD]'
+ sep_token: '[SEP]'
+ type: BertTokenizer
+ unk_token: '[UNK]'
+ type: BertProcessor
+top_k: 1
\ No newline at end of file
diff --git a/MindSPONGE/applications/research/Geneformer/images/loss.png b/MindSPONGE/applications/research/Geneformer/images/loss.png
new file mode 100644
index 0000000000000000000000000000000000000000..a3877d19d56612fa9cba4702e45a313dd41012a7
Binary files /dev/null and b/MindSPONGE/applications/research/Geneformer/images/loss.png differ
diff --git a/MindSPONGE/applications/research/Geneformer/images/model.png b/MindSPONGE/applications/research/Geneformer/images/model.png
new file mode 100644
index 0000000000000000000000000000000000000000..adb8879687cf0bbf087344ef2e2355cd533e8e8e
Binary files /dev/null and b/MindSPONGE/applications/research/Geneformer/images/model.png differ
diff --git a/MindSPONGE/applications/research/Geneformer/scripts/convert_weight.py b/MindSPONGE/applications/research/Geneformer/scripts/convert_weight.py
new file mode 100644
index 0000000000000000000000000000000000000000..718acf2d7c8d4c3f9b794b2fe04b8fa171457163
--- /dev/null
+++ b/MindSPONGE/applications/research/Geneformer/scripts/convert_weight.py
@@ -0,0 +1,199 @@
+# Copyright 2022 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""Convert checkpoint from torch/huggingface"""
+import argparse
+import numpy as np
+import torch
+from mindspore import save_checkpoint, Tensor
+
+def generate_params_dict(total_layers,
+ mindspore_params_per_layer,
+ torch_params_per_layer,
+ mindspore_additional_params,
+ torch_additional_params):
+ """
+ Generate the total parameter mapping of mindspore and pytorch.
+
+ Args:
+ total_layers(int): The total layers of the net.
+ mindspore_params_per_layer(list): The list of params per layer for the net of mindspore.
+ torch_params_per_layer(list): The list of params per layer for the net of pytorch.
+ mindspore_additional_params(list): The list of params outside the layer for the net of mindspore
+ torch_additional_params(list): The list of params outside the layer for the net of pytorch.
+
+ Returns:
+ A list of tuple. The first element is the parameter name of mindspore,
+ the another is the parameter name of pytorch.
+ """
+ mapped_params = list(zip(mindspore_params_per_layer, torch_params_per_layer))
+ ms_extend_param_list = []
+ torch_extend_param_list = []
+ for i in range(total_layers):
+ for ms_para, torch_para in mapped_params:
+ src = ms_para.format(i)
+ tgt = torch_para.format(i)
+
+ ms_extend_param_list.append(src)
+ torch_extend_param_list.append(tgt)
+
+ mapped_params = list(zip(mindspore_additional_params, torch_additional_params))
+ for ms_para, torch_para in mapped_params:
+ ms_extend_param_list.append(ms_para)
+ torch_extend_param_list.append(torch_para)
+
+ return list(zip(ms_extend_param_list, torch_extend_param_list))
+
+def get_converted_ckpt(mapped_params, weight_dict):
+ """
+    Convert the mapped pytorch weights into a mindspore checkpoint list.
+
+    Args:
+        mapped_params(list): Tuples of (mindspore name, pytorch name) from generate_params_dict.
+        weight_dict(dict): The loaded pytorch checkpoint (parameter name -> tensor).
+
+    Returns:
+        A list of {"name", "data"} dicts ready for mindspore.save_checkpoint.
+ """
+ new_ckpt_list = []
+
+    # mapped_params already contains the full set of (mindspore, pytorch) name pairs.
+ for src, tgt in mapped_params:
+ if tgt not in weight_dict:
+ if "LayerNorm.gamma" in tgt:
+ tgt = tgt.replace("gamma", "weight")
+ if "LayerNorm.beta" in tgt:
+ tgt = tgt.replace("beta", "bias")
+ try:
+ value = weight_dict[tgt].numpy()
+ if 'output.dense.weight' in tgt or 'intermediate.dense.weight' in tgt:
+ value = np.transpose(value, [1, 0])
+ print(f"Mapping table Mindspore:{src:<30} \t Torch:{tgt:<30} with shape {value.shape}")
+ new_ckpt_list.append({"data": Tensor(value), "name": src})
+ except KeyError:
+ print("keyerror: ", tgt)
+
+ return new_ckpt_list
+
+
+def split_torch_attention(state):
+ s = list(state.keys())
+ for name in s:
+ if name.endswith('attn.c_attn.weight') or name.endswith('attn.c_attn.bias'):
+ value = state.pop(name)
+ q, k, v = np.split(value.numpy(), 3, -1)
+ state[name + '.q'] = torch.tensor(q, dtype=value.dtype)
+ state[name + '.k'] = torch.tensor(k)
+ state[name + '.v'] = torch.tensor(v)
+
+
+if __name__ == '__main__':
+ parser = argparse.ArgumentParser(description="BERT convert script")
+ parser.add_argument('--layers',
+ type=int,
+ default=1,
+ help="The number of layers of the model to be converted.")
+ parser.add_argument("--torch_path",
+ type=str,
+ default=None,
+ required=True,
+ help="The torch checkpoint path.")
+ parser.add_argument("--mindspore_path",
+ type=str,
+ required=True,
+ default="The output mindspore checkpoint path.",
+ help="Use device nums, default is 128.")
+
+ opt = parser.parse_args()
+ state_dict = torch.load(opt.torch_path, map_location='cpu')
+
+ ms_name = [
+ "bert.bert_encoder.encoder.blocks.{}.attention.dense1.weight",
+ "bert.bert_encoder.encoder.blocks.{}.attention.dense1.bias",
+ "bert.bert_encoder.encoder.blocks.{}.attention.dense2.weight",
+ "bert.bert_encoder.encoder.blocks.{}.attention.dense2.bias",
+ "bert.bert_encoder.encoder.blocks.{}.attention.dense3.weight",
+ "bert.bert_encoder.encoder.blocks.{}.attention.dense3.bias",
+ "bert.bert_encoder.encoder.blocks.{}.attention.projection.weight",
+ "bert.bert_encoder.encoder.blocks.{}.attention.projection.bias",
+ "bert.bert_encoder.encoder.blocks.{}.layernorm2.gamma",
+ "bert.bert_encoder.encoder.blocks.{}.layernorm2.beta",
+ "bert.bert_encoder.encoder.blocks.{}.output.mapping.weight",
+ "bert.bert_encoder.encoder.blocks.{}.output.mapping.bias",
+ "bert.bert_encoder.encoder.blocks.{}.output.projection.weight",
+ "bert.bert_encoder.encoder.blocks.{}.output.projection.bias",
+ "bert.bert_encoder.encoder.blocks.{}.layernorm1.gamma",
+ "bert.bert_encoder.encoder.blocks.{}.layernorm1.beta",
+ ]
+
+ torch_name = [
+ "bert.encoder.layer.{}.attention.self.query.weight",
+ "bert.encoder.layer.{}.attention.self.query.bias",
+ "bert.encoder.layer.{}.attention.self.key.weight",
+ "bert.encoder.layer.{}.attention.self.key.bias",
+ "bert.encoder.layer.{}.attention.self.value.weight",
+ "bert.encoder.layer.{}.attention.self.value.bias",
+ "bert.encoder.layer.{}.attention.output.dense.weight",
+ "bert.encoder.layer.{}.attention.output.dense.bias",
+ "bert.encoder.layer.{}.attention.output.LayerNorm.gamma",
+ "bert.encoder.layer.{}.attention.output.LayerNorm.beta",
+ "bert.encoder.layer.{}.intermediate.dense.weight",
+ "bert.encoder.layer.{}.intermediate.dense.bias",
+ "bert.encoder.layer.{}.output.dense.weight",
+ "bert.encoder.layer.{}.output.dense.bias",
+ "bert.encoder.layer.{}.output.LayerNorm.gamma",
+ "bert.encoder.layer.{}.output.LayerNorm.beta",
+ ]
+
+ addition_mindspore = [
+ "bert.word_embedding.embedding_table",
+ "bert.embedding_postprocessor.full_position_embedding.embedding_table",
+ "bert.embedding_postprocessor.token_type_embedding.embedding_table",
+ "bert.embedding_postprocessor.layernorm.gamma",
+ "bert.embedding_postprocessor.layernorm.beta",
+ "bert.dense.weight",
+ "bert.dense.bias",
+ "bert.mlmloss.dense.weight",
+ "bert.mlmloss.dense.bias",
+ "bert.mlmloss.layernorm.gamma",
+ "bert.mlmloss.layernorm.beta",
+ "bert.mlmloss.vocab_dense.weight",
+ ]
+
+ addition_torch = [
+ "bert.embeddings.word_embeddings.weight",
+ "bert.embeddings.position_embeddings.weight",
+ "bert.embeddings.token_type_embeddings.weight",
+ "bert.embeddings.LayerNorm.gamma",
+ "bert.embeddings.LayerNorm.beta",
+ "bert.pooler.dense.weight",
+ "bert.pooler.dense.bias",
+ "cls.predictions.transform.dense.weight",
+ "cls.predictions.transform.dense.bias",
+ "cls.predictions.transform.LayerNorm.gamma",
+ "cls.predictions.transform.LayerNorm.beta",
+ "cls.predictions.decoder.weight"
+ ]
+
+ mapped_param = generate_params_dict(total_layers=opt.layers,
+ mindspore_params_per_layer=ms_name,
+ torch_params_per_layer=torch_name,
+ mindspore_additional_params=addition_mindspore,
+ torch_additional_params=addition_torch)
+ split_torch_attention(state_dict)
+ new_ckpt = get_converted_ckpt(mapped_param, state_dict)
+ save_checkpoint(new_ckpt, opt.mindspore_path)
+ print(f"Convert finished, the output is saved to {opt.mindspore_path}")
diff --git a/MindSPONGE/applications/research/Geneformer/scripts/main.py b/MindSPONGE/applications/research/Geneformer/scripts/main.py
new file mode 100644
index 0000000000000000000000000000000000000000..f3f8815c0d4b25dc5451c0bd72d9a11cdd552e81
--- /dev/null
+++ b/MindSPONGE/applications/research/Geneformer/scripts/main.py
@@ -0,0 +1,95 @@
+# Copyright 2022 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+# pylint: disable=C0411
+# pylint: disable=C0413
+
+"""main script"""
+import os
+import argparse
+import pickle
+import sys
+
+import yaml
+import mindspore as ms
+from mindspore.communication import init
+sys.path.append("..")
+from src.classifier import GeneClassifier
+
+
+def parse_args():
+ """args function"""
+ parser = argparse.ArgumentParser(description="run geneformer_classification")
+ parser.add_argument("--bert_config_path", type=str, required=True, help="bert config path")
+ parser.add_argument("--dataset_path", type=str, required=True, help="dataset path")
+ parser.add_argument("--do_train", type=bool, required=False, default=True, help="do train mode")
+ parser.add_argument("--data_parallel", type=bool, required=False, default=False, help="do data parallel")
+ parser.add_argument("--max_ncells", type=int, required=False, default=10_000, help="max ncells")
+ parser.add_argument("--freeze_layers", type=int, required=False, default=4, help="freeze_layers")
+ parser.add_argument("--num_crossval_splits", type=int, required=False, default=5, help="num_crossval_splits")
+ parser.add_argument("--forward_batch_size", type=int, required=False, default=200, help="forward_batch_size")
+ parser.add_argument("--nproc", type=int, required=False, default=16, help="nproc")
+ args = parser.parse_args()
+ return args
+
+def main(main_args):
+ """main"""
+ with open(main_args.dataset_path, 'r') as file:
+ data_config = yaml.safe_load(file)
+ gene_class_dict_path = data_config.get("gene_class_dict_path")
+ dataset_path = data_config.get("dataset_path")
+ output_prefix = data_config.get("output_prefix")
+ output_dir = data_config.get("output_dir")
+ data_output = data_config.get("data_output")
+ model_output = data_config.get("model_output")
+
+    # ensure the converted checkpoint from convert_weight.py exists
+ ms_model = os.path.join(model_output, "geneformer_mindspore.ckpt")
+ if os.path.isfile(ms_model) is False:
+ raise FileNotFoundError(f"geneformer_mindspore.ckpt not found in {model_output}.")
+ if not os.path.exists(data_output):
+ os.makedirs(data_output)
+ with open(gene_class_dict_path, "rb") as fp:
+ gene_class_dict = pickle.load(fp)
+
+ gc = GeneClassifier(gene_class_dict=gene_class_dict,
+ max_ncells=main_args.max_ncells,
+ freeze_layers=main_args.freeze_layers,
+ num_crossval_splits=main_args.num_crossval_splits,
+ forward_batch_size=main_args.forward_batch_size,
+ nproc=main_args.nproc,
+ config_path=main_args.bert_config_path,
+ do_train=main_args.do_train)
+
+ gc.prepare_data(input_data_file=dataset_path,
+ output_directory=data_output,
+ output_prefix=output_prefix)
+
+ all_metrics = gc.validate(model_directory=model_output,
+ prepared_input_data=f"{data_output}/{output_prefix}_labeled.dataset",
+ id_class_dict_file=f"{data_output}/{output_prefix}_id_class_dict.pkl",
+ output_directory=output_dir,
+ output_prefix=output_prefix)
+
+ print(all_metrics)
+
+if __name__ == '__main__':
+ actual_args = parse_args()
+ if actual_args.data_parallel:
+ ms.set_context(mode=ms.GRAPH_MODE, device_target="Ascend", device_id=int(os.getenv('DEVICE_ID')))
+ ms.set_auto_parallel_context(parallel_mode=ms.ParallelMode.DATA_PARALLEL, gradients_mean=True)
+ init()
+        ms.set_seed(1)
+    main(actual_args)
diff --git a/MindSPONGE/applications/research/Geneformer/scripts/run.sh b/MindSPONGE/applications/research/Geneformer/scripts/run.sh
new file mode 100644
index 0000000000000000000000000000000000000000..7785afc79593d8b8068ceb83e9b3d2d23afafc6c
--- /dev/null
+++ b/MindSPONGE/applications/research/Geneformer/scripts/run.sh
@@ -0,0 +1,61 @@
+#!/bin/bash
+pwd
+# Default file paths
+default_bert_config_path="../config/run_geneformer_args.yaml"
+default_dataset_path="../config/geneformer_config.yaml"
+
+# Function to display help message
+show_help() {
+ echo "Usage: bash run.sh "
+ echo "Usage: bash run.sh $default_bert_config_path $default_dataset_path"
+ echo
+ echo "Options:"
+ echo " --bert_config_path Path to the BERT configuration file (default: $default_bert_config_path)"
+ echo " --dataset_path Path to the dataset configuration file (default: $default_dataset_path)"
+ echo " -h, --help Show this help message"
+ echo
+ echo "This script runs the Python program 'main.py' with the specified configuration files."
+ echo "If no paths are provided, default values will be used."
+}
+
+# Parse command-line arguments
+while [[ $# -gt 0 ]]; do
+ case $1 in
+ --bert_config_path)
+ bert_config_path="$2"
+ shift 2
+ ;;
+ --dataset_path)
+ dataset_path="$2"
+ shift 2
+ ;;
+ -h|--help)
+ show_help
+ exit 0
+ ;;
+ *)
+ echo "Unknown option: $1"
+ show_help
+ exit 1
+ ;;
+ esac
+done
+
+# Fall back to the defaults for any path not set via the options above
+# (the option parser has already consumed the positional parameters).
+bert_config_path="${bert_config_path:-$default_bert_config_path}"
+dataset_path="${dataset_path:-$default_dataset_path}"
+
+# Check if the BERT config file exists
+if [ ! -f "$bert_config_path" ]; then
+ echo "Error: BERT config file does not exist: $bert_config_path"
+ exit 1
+fi
+
+# Check if the dataset config file exists
+if [ ! -f "$dataset_path" ]; then
+ echo "Error: Dataset config file does not exist: $dataset_path"
+ exit 1
+fi
+
+# If files exist, proceed with running the Python script
+python main.py --bert_config_path "$bert_config_path" --dataset_path "$dataset_path"
diff --git a/MindSPONGE/applications/research/Geneformer/scripts/run_8p.sh b/MindSPONGE/applications/research/Geneformer/scripts/run_8p.sh
new file mode 100644
index 0000000000000000000000000000000000000000..ba2c1b7b8396864e8d8e5f5d4693bdef01104bea
--- /dev/null
+++ b/MindSPONGE/applications/research/Geneformer/scripts/run_8p.sh
@@ -0,0 +1,73 @@
+#!/bin/bash
+
+# Default file paths
+default_bert_config_path="../config/run_geneformer_args.yaml"
+default_dataset_path="../config/geneformer_config.yaml"
+
+default_RANK_TABLE_FILE="../config/rank_table.json"
+default_RANK_SIZE=8
+
+# Function to display help message
+display_help() {
+ echo "Usage: bash run_8p.sh "
+ echo "Usage: bash run_8p.sh $default_bert_config_path $default_dataset_path $default_RANK_TABLE_FILE $default_RANK_SIZE"
+ echo
+ echo "Options:"
+ echo " -h, --help Show this help message"
+ echo " Path to the BERT configuration file (default: $default_bert_config_path)"
+ echo " Path to the dataset configuration file (default: $default_dataset_path)"
+ echo " Path to the rank table file (default: $default_RANK_TABLE_FILE)"
+ echo " Number of devices for distributed training (default: $default_RANK_SIZE)"
+ echo
+ exit 0
+}
+
+
+# Check if command-line arguments are provided for file paths, otherwise use defaults
+bert_config_path="${1:-$default_bert_config_path}"
+dataset_path="${2:-$default_dataset_path}"
+RANK_TABLE_FILE="${3:-$default_RANK_TABLE_FILE}"
+RANK_SIZE="${4:-$default_RANK_SIZE}"
+
+# Check if help option is provided
+if [[ "$1" == "-h" || "$1" == "--help" ]]; then
+ display_help
+fi
+
+
+# Check if the BERT config file exists
+if [ ! -f "$bert_config_path" ]; then
+ echo "Error: BERT config file does not exist: $bert_config_path"
+ exit 1
+fi
+
+# Check if the dataset config file exists
+if [ ! -f "$dataset_path" ]; then
+ echo "Error: Dataset config file does not exist: $dataset_path"
+ exit 1
+fi
+
+# Check if the rank table file exists
+if [ ! -f "$RANK_TABLE_FILE" ]; then
+    echo "Error: Rank table file does not exist: $RANK_TABLE_FILE"
+ exit 1
+fi
+
+export RANK_TABLE_FILE
+export RANK_SIZE
+
+DIR_NAME="runlog"
+if [ -d "$DIR_NAME" ]; then
+ echo "Directory $DIR_NAME already exists"
+else
+ mkdir "$DIR_NAME"
+ echo "Directory $DIR_NAME has been created"
+fi
+for((i=0;i<${RANK_SIZE};i++))
+do
+ export DEVICE_ID=$i
+ export RANK_ID=$i
+ echo "start training for device $i"
+ env > env$i.log
+ python3 ./main.py --bert_config_path "$bert_config_path" --dataset_path "$dataset_path" --nproc 16 --data_parallel True > $DIR_NAME/train_device$i.log 2>&1 &
+done
diff --git a/MindSPONGE/applications/research/Geneformer/src/__init__.py b/MindSPONGE/applications/research/Geneformer/src/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..024564c568c9c7675185fe4eef9650aa897e8cf2
--- /dev/null
+++ b/MindSPONGE/applications/research/Geneformer/src/__init__.py
@@ -0,0 +1,25 @@
+# Copyright 2023 Huawei Technologies Co., Ltd & CPL YiQin GAO Research Group
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""init"""
+from pathlib import Path
+
+# .pkl files download from https://huggingface.co/ctheodoris/Geneformer/tree/main/geneformer
+
+GENE_MEDIAN_FILE = Path(__file__).parent / "gene_median_dictionary_gc95M.pkl"
+TOKEN_DICTIONARY_FILE = Path(__file__).parent / "token_dictionary_gc95M.pkl"
+ENSEMBL_DICTIONARY_FILE = Path(__file__).parent / "gene_name_id_dict_gc95M.pkl"
+ENSEMBL_MAPPING_FILE = Path(__file__).parent / "ensembl_mapping_dict_gc95M.pkl"
diff --git a/MindSPONGE/applications/research/Geneformer/src/classifier.py b/MindSPONGE/applications/research/Geneformer/src/classifier.py
new file mode 100644
index 0000000000000000000000000000000000000000..a4ad4596f2693347ecb7bcd7da38d9b6ece333ef
--- /dev/null
+++ b/MindSPONGE/applications/research/Geneformer/src/classifier.py
@@ -0,0 +1,402 @@
+# Copyright 2023 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+# pylint: disable=C0411
+# pylint: disable=C0413
+
+"""classifier script"""
+import os
+import time
+import datetime
+import logging
+import pickle
+from pathlib import Path
+import numpy as np
+from sklearn.metrics import f1_score, accuracy_score
+from tqdm.auto import tqdm
+
+import mindspore
+from mindspore import Tensor, nn
+from mindspore.ops import operations as P
+import mindspore.dataset as ds
+from mindformers.tools.register import MindFormerConfig
+from mindformers import AutoConfig, Trainer, BertForTokenClassification
+
+from . import TOKEN_DICTIONARY_FILE
+from . import perturber_utils as pu
+from . import classifier_utils as cu
+from . import preparedata as pre
+logger = logging.getLogger(__name__)
+
+
+def pre_label(data):
+ logits = Tensor(data, mindspore.float16)
+ softmax = P.Softmax()
+ probabilities = softmax(logits)
+ predicted_label = probabilities.argmax()
+ predicted_label = predicted_label.asnumpy().item()
+ return predicted_label
+
+
+class GeneClassifier(nn.Cell):
+ """GeneClassifier"""
+ def __init__(
+ self,
+ quantize=False,
+ gene_class_dict=None,
+ filter_data=None,
+ rare_threshold=0,
+ max_ncells=None,
+ max_ncells_per_class=None,
+ training_args=None,
+ ray_config=None,
+ freeze_layers=0,
+ num_crossval_splits=1,
+        train_size=0.8,
+        valid_size=0.1,
+        test_size=0.1,
+ stratify_splits_col=None,
+ no_eval=False,
+ forward_batch_size=100,
+ nproc=4,
+ config_path=None,
+ do_train=False
+ ):
+ """
+        Initialize a classifier for cell state or gene classification.
+
+        Args:
+            quantize(bool): Whether to fine-tune a quantized model.
+            gene_class_dict(dict): Gene classes to fine-tune model to distinguish.
+            filter_data(dict): Dictionary specifying .dataset column name and list of values to filter by.
+ rare_threshold(int): Threshold below which rare cell states should be removed.
+ max_ncells(int): Maximum number of cells to use for fine-tuning.
+ max_ncells_per_class(int): Maximum number of cells per cell class to use for fine-tuning.
+ training_args(dict): Training arguments for fine-tuning.
+ ray_config(dict): Training argument ranges for tuning hyperparameters with Ray.
+ freeze_layers(int): Number of layers to freeze from fine-tuning.
+            train_size(float): Proportion of data for the training set.
+            valid_size(float): Proportion of data for the validation set.
+            test_size(float): Proportion of data held out as the test set.
+ stratify_splits_col(dict): Proportion of each class in this column will
+ be the same in the splits as in the original dataset.
+ no_eval(bool): Will skip eval step and use all data for training.
+            forward_batch_size(int): Batch size for forward pass (for evaluation, not training).
+            nproc(int): Number of CPU processes to use.
+ Returns:
+ GeneClassifier object
+ """
+ super(GeneClassifier, self).__init__()
+ self.quantize = quantize
+ self.gene_class_dict = gene_class_dict
+ self.filter_data = filter_data
+ self.rare_threshold = rare_threshold
+ self.max_ncells = max_ncells
+ self.max_ncells_per_class = max_ncells_per_class
+ self.training_args = training_args
+ self.ray_config = ray_config
+ self.freeze_layers = freeze_layers
+ self.num_crossval_splits = num_crossval_splits
+ self.train_size = train_size
+ self.valid_size = valid_size
+ self.oos_test_size = test_size
+ self.eval_size = self.valid_size / (self.train_size + self.valid_size)
+ self.stratify_splits_col = stratify_splits_col
+ self.no_eval = no_eval
+ self.forward_batch_size = forward_batch_size
+ self.nproc = nproc
+ self.config_path = config_path
+ self.do_train = do_train
+
+ # load token dictionary (Ensembl IDs:token)
+ with open(TOKEN_DICTIONARY_FILE, "rb") as f:
+ self.gene_token_dict = pickle.load(f)
+ self.token_gene_dict = {v: k for k, v in self.gene_token_dict.items()}
+ self.gene_class_dict = {
+ k: list(set([self.gene_token_dict.get(gene) for gene in v]))
+ for k, v in self.gene_class_dict.items()
+ }
+        empty_classes = []
+        for k, v in self.gene_class_dict.items():
+            if not v:
+                empty_classes.append(k)
+        if empty_classes:
+            logger.error("No token mapping found for genes in classes: %s", empty_classes)
+
+ def prepare_data(
+ self,
+ input_data_file,
+ output_directory,
+ output_prefix,
+ ):
+ """
+ prepare_data
+
+ Args:
+ input_data_file(path): Path to directory containing .dataset input.
+ output_directory(path): Path to directory where prepared data will be saved.
+ output_prefix(str): Prefix for output file.
+ balanced other attributes.
+ Returns:
+ Save processed data to dir.
+ """
+ data = pu.load_and_filter(self.filter_data, self.nproc, input_data_file)
+ data, id_class_dict = cu.label_classes(
+ data, self.gene_class_dict, self.nproc
+ )
+ id_class_output_path = (
+ Path(output_directory) / f"{output_prefix}_id_class_dict"
+ ).with_suffix(".pkl")
+ with open(id_class_output_path, "wb") as f:
+ pickle.dump(id_class_dict, f)
+ data_output_path = (
+ Path(output_directory) / f"{output_prefix}_labeled"
+ ).with_suffix(".dataset")
+ data.save_to_disk(str(data_output_path))
+
+
+ def evaluate_model(
+ self,
+ model_directory,
+ eval_data,
+ eval_batch_size=16,
+ max_len=2048,
+ mask_label=-100,
+ ):
+ """
+ evaluate_model
+
+ Args:
+            model_directory(path): Path to the directory containing the fine-tuned checkpoints.
+            eval_data(dataset): Loaded evaluation .dataset input.
+            eval_batch_size(int): Batch size for evaluation.
+            max_len(int): Maximum sequence length.
+            mask_label(int): Label value ignored when computing metrics.
+        Returns:
+            The accuracy score on eval_data.
+ """
+ token_type_ids = np.zeros((eval_batch_size, max_len), dtype=np.int32)
+ token_type_ids = Tensor(token_type_ids, dtype=mindspore.int32)
+        # generator_eval_data yields (input_ids, labels, attention_mask), in that order
+        dataset = ds.GeneratorDataset(pre.generator_eval_data(eval_data, max_len),
+                                      column_names=['input_ids', 'labels', 'input_mask'])
+ dataset = dataset.batch(eval_batch_size)
+ geneformer_config = AutoConfig.from_pretrained(self.config_path)
+ checkpoint_dir = os.path.join(model_directory, "checkpoint/rank_0/")
+
+ if not os.path.exists(checkpoint_dir):
+ raise FileNotFoundError(checkpoint_dir + " dir not found")
+ max_ckpt = pu.find_latest_ckpt(checkpoint_dir)
+ print(f"use ckpt file: {max_ckpt}")
+ geneformer_config.load_checkpoint = max_ckpt
+ geneformer_config.checkpoint_name_or_path = max_ckpt
+ eval_model = BertForTokenClassification(geneformer_config)
+ eval_model.set_train(False)
+ logit_list = []
+ true_list = []
+ eval_start_time = time.time()
+        for data in dataset:
+ input_data, label, mask = data
+ input_data = input_data.asnumpy().tolist()
+ input_data = Tensor(input_data, dtype=mindspore.int32)
+ mask = Tensor(mask, dtype=mindspore.int32)
+ output = eval_model(input_data, mask, token_type_ids)
+ logit_list.extend(output.asnumpy())
+ true_list.extend(label.asnumpy())
+ logit_label_paired = [
+ (logit, label)
+ for batch_logit, batch_labels in zip(logit_list, true_list)
+ for logit, label in zip(batch_logit, batch_labels)
+ if label != mask_label
+ ]
+ pre_list = []
+        for logit, _ in logit_label_paired:
+            pre_list.append(pre_label(logit))
+ label_true = [item[1] for item in logit_label_paired]
+ f1 = f1_score(label_true, pre_list, average='binary')
+ acc = accuracy_score(label_true, pre_list)
+ eval_end_time = time.time()
+ print("eval_time: ", eval_end_time - eval_start_time)
+ print(f"F1 Score: {f1}")
+ print(f"Accuracy: {acc}")
+ return acc
+
+ def train_classifier(
+ self,
+ model_directory,
+ train_data,
+ eval_data,
+ train_batch_size=12,
+ ):
+
+ """
+ Fine-tune model for cell state or gene classification.
+
+ Args:
+            model_directory(path): Path to directory containing the model.
+            train_data(dataset): Loaded training .dataset input.
+            eval_data(dataset): Loaded evaluation .dataset input.
+            train_batch_size(int): Batch size for training.
+ Returns:
+ A trainer object
+ """
+
+ # Validate and prepare data
+ train_data, eval_data = cu.validate_and_clean_cols(
+ train_data, eval_data
+ )
+ # Load model and training args
+ model = pu.load_model(
+ model_directory,
+ "train",
+ quantize=self.quantize,
+ config_path=self.config_path
+ )
+ def_freeze_layers = self.freeze_layers
+ if def_freeze_layers > 0:
+ for param in model.get_parameters(expand=True):
+ if param.name.startswith("bert.bert_encoder.encoder") and int(
+ param.name.split("bert.bert_encoder.encoder.blocks.")[1].split(".")[0]) < def_freeze_layers:
+ param.requires_grad = False
+
+ my_data = pre.GeneratorTrainData(train_data)
+ dataset = ds.GeneratorDataset(source=my_data,
+ column_names=["input_ids", "input_mask", 'segment_ids', 'label_ids'])
+ dataset = dataset.batch(train_batch_size)
+ config = MindFormerConfig(self.config_path)
+ config.output_dir = model_directory
+ trainer = Trainer(task='token_classification',
+ model=model,
+ args=config,
+ train_dataset=dataset,
+ eval_dataset=dataset)
+ if self.do_train:
+ trainer.train()
+ return trainer
+
+ def validate(
+ self,
+ model_directory,
+ prepared_input_data,
+ id_class_dict_file,
+ output_directory,
+ output_prefix,
+ gene_balance=False,
+ predict_trainer=False,
+ save_split_datasets=True,
+ n_splits=2
+ ):
+ """
+ (Cross-)validate cell state or gene classifier.
+
+ Args:
+            model_directory(path): Path to directory containing model.
+ prepared_input_data(path): Path to directory containing _labeled.dataset
+ previously prepared by Classifier.prepare_data.
+ id_class_dict_file(path): Path to _id_class_dict.pkl previously prepared by Classifier.prepare_data.
+ output_directory(path): Path to directory where model and eval data will be saved.
+ output_prefix(str): Prefix for output files.
+            gene_balance(bool): Whether to automatically balance genes in training set.
+            predict_trainer(bool): Whether or not to save eval predictions from trainer.
+            save_split_datasets(bool): Whether or not to save train, valid, and test gene-labeled datasets.
+            n_splits(int): Number of cross-validation splits.
+        Returns:
+            A list of accuracy scores, one per split.
+ """
+ # load numerical id to class dictionary (id:class)
+ with open(id_class_dict_file, "rb") as f:
+ id_class_dict = pickle.load(f)
+ class_id_dict = {v: k for k, v in id_class_dict.items()}
+ # load previously filtered and prepared data
+ data = pu.load_and_filter(None, self.nproc, prepared_input_data)
+ data = data.shuffle(seed=42) # reshuffle in case users provide unshuffled data
+
+ # define output directory path
+ current_date = datetime.datetime.now()
+ datestamp = f"{str(current_date.year)[-2:]}{current_date.month:02d}{current_date.day:02d}"
+ if output_directory[-1:] != "/": # add slash for dir if not present
+ output_directory = output_directory + "/"
+ output_dir = f"{output_directory}{datestamp}_geneformer_geneClassifier_{output_prefix}/"
+        os.makedirs(output_dir, exist_ok=True)
+
+ # get number of classes for classifier
+ num_classes = cu.get_num_classes(id_class_dict)
+ iteration_num = 1
+ targets = pu.flatten_list(self.gene_class_dict.values())
+        labels = pu.flatten_list(
+            [
+                [class_id_dict[label]] * len(class_targets)
+                for label, class_targets in self.gene_class_dict.items()
+            ]
+        )
+ skf = cu.StratifiedKFold3(n_splits=n_splits, random_state=0, shuffle=True)
+ test_ratio = self.oos_test_size / (self.eval_size + self.oos_test_size)
+ result_list = []
+ for train_index, eval_index, test_index in tqdm(
+ skf.split(targets, labels, test_ratio)
+ ):
+ train_data, eval_data = cu.gene_split_data(
+ data,
+ targets,
+ labels,
+ train_index,
+ eval_index,
+ self.max_ncells,
+ iteration_num,
+ self.nproc,
+ gene_balance,
+ )
+
+ if save_split_datasets is True:
+ for split_name in ["train", "valid"]:
+ labeled_dataset_output_path = (
+ Path(output_dir)
+ / f"{output_prefix}_{split_name}_gene_labeled_ksplit{iteration_num}"
+ ).with_suffix(".dataset")
+ if split_name == "train":
+ train_data.save_to_disk(str(labeled_dataset_output_path))
+ elif split_name == "valid":
+ eval_data.save_to_disk(str(labeled_dataset_output_path))
+
+ if self.oos_test_size > 0:
+ test_data = cu.gene_classifier_split(
+ data,
+ targets,
+ labels,
+ test_index,
+ "test",
+ self.max_ncells,
+ iteration_num,
+ self.nproc,
+ )
+ if save_split_datasets is True:
+ test_labeled_dataset_output_path = (
+ Path(output_dir)
+ / f"{output_prefix}_test_gene_labeled_ksplit{iteration_num}"
+ ).with_suffix(".dataset")
+ test_data.save_to_disk(str(test_labeled_dataset_output_path))
+            self.train_classifier(
+                model_directory,
+                train_data,
+                eval_data,
+            )
+ result = self.evaluate_model(
+ model_directory,
+ eval_data,
+ )
+ result_list.append(result)
+ return result_list
diff --git a/MindSPONGE/applications/research/Geneformer/src/classifier_utils.py b/MindSPONGE/applications/research/Geneformer/src/classifier_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..a0145e291a727a5b0b5667a39ff05b4d47e1b853
--- /dev/null
+++ b/MindSPONGE/applications/research/Geneformer/src/classifier_utils.py
@@ -0,0 +1,223 @@
+# Copyright 2023 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""classifier_utils script"""
+import random
+import logging
+import numpy as np
+from sklearn.model_selection import StratifiedKFold, train_test_split
+
+from . import perturber_utils as pu
+logger = logging.getLogger(__name__)
+
+
+def label_classes(data, gene_class_dict, nproc):
+ """remove cells without any of the target genes"""
+ def if_contains_label(example):
+ a = pu.flatten_list(gene_class_dict.values())
+ b = example["input_ids"]
+ return not set(a).isdisjoint(b)
+ data = data.filter(if_contains_label, num_proc=nproc)
+    if len(data) == 0:
+ logger.error(
+ "No cells remain after filtering for target genes. Check target gene list."
+ )
+ label_set = gene_class_dict.keys()
+    class_id_dict = dict(zip(label_set, range(len(label_set))))
+ id_class_dict = {v: k for k, v in class_id_dict.items()}
+
+ def classes_to_ids(example):
+ example["labels"] = label_gene_classes(
+ example, class_id_dict, gene_class_dict
+ )
+ return example
+ data = data.map(classes_to_ids, num_proc=nproc)
+ return data, id_class_dict
+
+
+def downsample_and_shuffle(data, max_ncells, max_ncells_per_class, cell_state_dict):
+ """downsample and shuffle datasets"""
+ data = data.shuffle(seed=42)
+ num_cells = len(data)
+ # if max number of cells is defined, then subsample to this max number
+ if max_ncells:
+ if num_cells > max_ncells:
+ data = data.select([i for i in range(max_ncells)])
+ if max_ncells_per_class:
+ class_labels = data[cell_state_dict["state_key"]]
+ random.seed(42)
+ subsample_indices = subsample_by_class(class_labels, max_ncells_per_class)
+ data = data.select(subsample_indices)
+ return data
+
+
+def remove_cols(data, cols_to_keep):
+ """remove and cols datasets"""
+ other_cols = list(data.features.keys())
+ other_cols = [ele for ele in other_cols if ele not in cols_to_keep]
+ data = data.remove_columns(other_cols)
+ return data
+
+
+def validate_and_clean_cols(train_data, eval_data):
+ """validate and clean cols datasets"""
+ # validate that data has expected label column and remove others
+ label_col = "labels"
+
+ cols_to_keep = [label_col] + ["input_ids", "length"]
+ if label_col not in train_data.column_names:
+ logger.error("train_data must contain column %s with class labels.", label_col)
+ else:
+ train_data = remove_cols(train_data, cols_to_keep)
+
+ if eval_data:
+ if label_col not in eval_data.column_names:
+ logger.error(
+ "eval_data must contain column %s with class labels.", label_col
+ )
+ else:
+ eval_data = remove_cols(eval_data, cols_to_keep)
+ return train_data, eval_data
+
+
+def label_gene_classes(example, class_id_dict, gene_class_dict):
+ """label gene classes"""
+ return [
+ class_id_dict.get(gene_class_dict.get(token_id, -100), -100)
+ for token_id in example["input_ids"]
+ ]
+
+
+def get_num_classes(id_class_dict):
+ """get classes num"""
+ return len(set(id_class_dict.values()))
+
+
+def gene_split_data(
+ data,
+ targets,
+ labels,
+ train_index,
+ eval_index,
+ max_ncells,
+ iteration_num,
+ num_proc,
+ balance=False,
+):
+ """split gene data"""
+ # generate cross-validation splits
+ train_data = gene_classifier_split(
+ data,
+ targets,
+ labels,
+ train_index,
+ "train",
+ max_ncells,
+ iteration_num,
+ num_proc,
+ balance,
+ )
+ eval_data = gene_classifier_split(
+ data,
+ targets,
+ labels,
+ eval_index,
+ "eval",
+ max_ncells,
+ iteration_num,
+ num_proc,
+ balance,
+ )
+ return train_data, eval_data
+
+
+def gene_classifier_split(
+ data,
+ targets,
+ labels,
+ index,
+ subset_name,
+ max_ncells,
+ iteration_num,
+ num_proc,
+ balance=False,
+):
+ """split gene classifier"""
+ # generate cross-validation splits
+ targets = np.array(targets)
+ labels = np.array(labels)
+ targets_subset = targets[index]
+ labels_subset = labels[index]
+ label_dict_subset = dict(zip(targets_subset, labels_subset))
+
+ # function to filter by whether contains train or eval labels
+ def if_contains_subset_label(example):
+ a = targets_subset
+ b = example["input_ids"]
+ return not set(a).isdisjoint(b)
+
+ # filter dataset for examples containing classes for this split
+ logger.info("Filtering data for %s genes in split %d", subset_name, iteration_num)
+ subset_data = data.filter(if_contains_subset_label, num_proc=num_proc)
+ percentage_filtered = round((1 - len(subset_data) / len(data)) * 100)
+ logger.info(
+ "Filtered %d%%; %d remain\n", percentage_filtered, len(subset_data)
+ )
+
+    # balance gene subsets if train
+    # NOTE: balance_gene_split is assumed to be provided by the upstream Geneformer
+    # utilities; it is not defined in this module, so balance=True requires it.
+ if (subset_name == "train") and (balance is True):
+ subset_data, label_dict_subset = balance_gene_split(
+ subset_data, label_dict_subset, num_proc
+ )
+
+ # subsample to max_ncells
+ subset_data = downsample_and_shuffle(subset_data, max_ncells, None, None)
+
+ # relabel genes for this split
+ def subset_classes_to_ids(example):
+ example["labels"] = [
+ label_dict_subset.get(token_id, -100) for token_id in example["input_ids"]
+ ]
+ return example
+
+ subset_data = subset_data.map(subset_classes_to_ids, num_proc=num_proc)
+ return subset_data
+
+
+class StratifiedKFold3(StratifiedKFold):
+ """StratifiedKFold3"""
+ def split(self, targets, labels, test_ratio=0.5, groups=None):
+ """split"""
+ s = super().split(targets, labels, groups)
+ for train_indxs, test_indxs in s:
+ if test_ratio == 0:
+ yield train_indxs, test_indxs, None
+ else:
+ labels_test = np.array(labels)[test_indxs]
+ valid_indxs, test_indxs = train_test_split(
+ test_indxs,
+ stratify=labels_test,
+ test_size=test_ratio,
+ random_state=0,
+ )
+ yield train_indxs, valid_indxs, test_indxs
+
+
+def load_data(dataset):
+ """load data"""
+ input_ids = dataset['input_ids']
+ length = dataset['length']
+ labels = dataset['labels']
+ return input_ids, length, labels
diff --git a/MindSPONGE/applications/research/Geneformer/src/perturber_utils.py b/MindSPONGE/applications/research/Geneformer/src/perturber_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..87534fa319f407003fb224b92b43323c3e954d25
--- /dev/null
+++ b/MindSPONGE/applications/research/Geneformer/src/perturber_utils.py
@@ -0,0 +1,99 @@
+# Copyright 2023 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""perturber utiles script"""
+import os
+import re
+import glob
+from datasets import load_from_disk
+
+from mindformers import AutoConfig, BertForTokenClassification
+
+
+def flatten_list(megalist):
+ """flatten list"""
+ return [item for sublist in megalist for item in sublist]
+
+
+def parse_filename(filename):
+ """parse filename"""
+ match = re.search(r"mindformers_rank_(\d+)-(\d+)_(\d+)\.ckpt", filename)
+ if match:
+ return int(match.group(1)), int(match.group(2)), int(match.group(3))
+ return None
+
+
+def filter_data_by_criteria(example, criteria):
+    """filter data by criteria"""
+    return example[criteria['key']] in criteria['value']
+
+
+def filter_by_dict(data, filter_data, nproc):
+    """Keep examples whose value for every filter key is in the allowed list."""
+    for key, value in filter_data.items():
+        criteria = {'key': key, 'value': value}
+        data = data.filter(lambda example, c=criteria: filter_data_by_criteria(example, c),
+                           num_proc=nproc)
+    return data
+
+
+def load_and_filter(filter_data, nproc, input_data_file):
+ """load and filter"""
+ data = load_from_disk(input_data_file)
+ if filter_data:
+ data = filter_by_dict(data, filter_data, nproc)
+ return data
+
+
+def quant_layers(model):
+ """quant layers"""
+ layer_nums = []
+ for name, _ in model.parameters_and_names():
+ if name.endswith(".attention.projection.weight"):
+ layer_nums.append(int(name.split(".attention.projection.weight")[0].split(".")[-1]))
+ return max(layer_nums) + 1
+
+
+def find_latest_ckpt(directory):
+ """find latest ckpt"""
+ ckpt_files = glob.glob(os.path.join(directory, '*.ckpt'))
+ if not ckpt_files:
+ return None
+ latest_ckpt = None
+ latest_time = 0
+ for ckpt_file in ckpt_files:
+ creation_time = os.path.getctime(ckpt_file)
+ if creation_time > latest_time:
+ latest_time = creation_time
+ latest_ckpt = ckpt_file
+ return latest_ckpt
+
+
+def load_model(model_directory, mode, quantize=False, config_path="config/run_geneformer_args.yaml"):
+    """load model weights; `quantize` is accepted for API compatibility with the
+    GeneClassifier caller, but quantized loading is not implemented in this port"""
+ geneformer_config = AutoConfig.from_pretrained(config_path)
+ if not os.path.exists(os.path.join(model_directory, "geneformer_mindspore.ckpt")):
+ raise FileNotFoundError(os.path.join(model_directory, "geneformer_mindspore.ckpt") + " not found")
+ geneformer_config.load_checkpoint = os.path.join(model_directory, "geneformer_mindspore.ckpt")
+ geneformer_config.checkpoint_name_or_path = os.path.join(model_directory, "geneformer_mindspore.ckpt")
+ model = BertForTokenClassification(geneformer_config)
+ if mode == "eval":
+        model.set_train(False)  # MindSpore cells use set_train(False) instead of .eval()
+ return model
diff --git a/MindSPONGE/applications/research/Geneformer/src/preparedata.py b/MindSPONGE/applications/research/Geneformer/src/preparedata.py
new file mode 100644
index 0000000000000000000000000000000000000000..8640a36e2168ec421e0789776a97f2181c36b376
--- /dev/null
+++ b/MindSPONGE/applications/research/Geneformer/src/preparedata.py
@@ -0,0 +1,72 @@
+# Copyright 2023 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""prepare data script"""
+import numpy as np
+
+
+def convert_eval_data(data, max_len):
+ """convert eval data"""
+ np.set_printoptions(threshold=np.inf)
+ mask_id = np.ones(len(data["input_ids"]), dtype=int)
+ mask_id = np.pad(mask_id, (0, max_len - len(data["input_ids"])), mode="constant", constant_values=0)
+ data["input_ids"] = np.pad(data["input_ids"], (0, max_len - len(data["input_ids"])),
+ mode="constant", constant_values=0)
+ data["labels"] = np.pad(data["labels"], (0, max_len - len(data["labels"])), mode="constant", constant_values=-100)
+ return data["input_ids"], data["labels"], mask_id
+
+
+def generator_eval_data(eval_data, max_len):
+ """generator eval data"""
+ for data in eval_data:
+ yield convert_eval_data(data, max_len)
+
+
+def load_data(dataset):
+ """load data"""
+ input_ids = dataset['input_ids']
+ length = dataset['length']
+ labels = dataset['labels']
+ return input_ids, length, labels
+
+
+class GeneratorTrainData:
+ """GeneratorTrainData"""
+ def __init__(self, dataset):
+ input_ids, length, labels = load_data(dataset)
+ self._input_ids = input_ids
+ self._length = length
+ self._labels = labels
+
+ def convert_data(self, index, max_len=2048):
+ """convert_data"""
+ input_ids = self._input_ids[index]
+ labels = self._labels[index]
+ mask_id = np.ones(len(input_ids), dtype=int)
+ mask_id = np.pad(mask_id, (0, max_len - len(input_ids)), mode="constant", constant_values=0)
+ mask_id = mask_id.astype(np.int32)
+ data_input_ids = np.pad(input_ids, (0, max_len - len(input_ids)), mode="constant", constant_values=0)
+ data_input_ids = data_input_ids.astype(np.int32)
+ data_labels = np.pad(labels, (0, max_len - len(labels)), mode="constant", constant_values=-100)
+ data_labels = data_labels.astype(np.int32)
+ token_type_ids = np.zeros((max_len), dtype=np.int32)
+ return data_input_ids, mask_id, token_type_ids, data_labels
+
+ def __getitem__(self, index):
+ data_input_ids, mask_id, token_type_ids, data_labels = self.convert_data(index)
+ return data_input_ids, mask_id, token_type_ids, data_labels
+
+ def __len__(self):
+ return len(self._input_ids)