diff --git a/scBERT/README_CN.md b/scBERT/README_CN.md new file mode 100644 index 0000000000000000000000000000000000000000..0ff47332ea088cca0e79d0aef9deed264c1b5bdc --- /dev/null +++ b/scBERT/README_CN.md @@ -0,0 +1,300 @@ +# 目录 + +- [目录](#目录) +- [scBERT描述](#scBERT描述) +- [模型架构](#模型架构) +- [数据集](#数据集) +- [特性](#特性) +- [环境要求](#环境要求) +- [快速入门](#快速入门) +- [脚本说明](#脚本说明) + - [脚本及样例代码](#脚本及样例代码) + - [脚本参数](#脚本参数) + - [训练过程](#训练过程) + - [单卡训练](#单卡训练) + - [分布式训练](#分布式训练) + - [微调过程](#微调过程) + - [微调](#微调) + - [python命令启动](#python命令启动) + - [shell脚本启动](#shell脚本启动) + - [推理过程](#推理过程) + - [用法](#用法) + - [结果](#结果) +- [随机情况说明](#随机情况说明) +- [ModelZoo主页](#modelzoo主页) + +# scBERT描述 + +细胞类型的可靠注释是单细胞RNA测序数据下游分析的前提条件。现有的注释算法通常面临批次效应处理不当、缺乏精心筛选的标记基因列表,或难以利用基因-基因之间潜在相互作用信息等问题。受大规模预训练语言模型的启发,我们提出了一种基于预训练深度神经网络的模型scBERT(单细胞双向编码器表示转换模型),以克服上述挑战。scBERT采用了深度学习领域的预训练和微调的最新范式。在scBERT的第一个阶段,它通过在大量未标记的scRNA-seq数据上进行预训练,获得了对基因-基因相互作用的广泛理解。然后,经过预训练的scBERT可以通过监督微调用于对未见的和特定用户的scRNA-seq数据进行细胞注释任务。更多信息请参考 [https://www.biorxiv.org/content/10.1101/2021.12.05.471261v1](https://www.biorxiv.org/content/10.1101/2021.12.05.471261v1)。 + +# 模型架构 + +本模型主要包含以下组件: + +1. Gene2vec位置编码模块: + - 使用预训练的基因向量进行编码 + - 维度: 200 + +2. Performer核心编码器: + - 多头注意力机制: 10个头 + - 前馈网络层 + - Layer Normalization层 + - Dropout正则化 + +3. 预训练任务: + - 掩码语言建模(MLM) + - 掩码概率: 0.15 + - 替换概率: 0.9 + +4. 下游分类头: + - 全连接层 + - ReLU激活 + - Dropout层 + +# 数据集 + +使用的数据集:[panglao_10000 Zheng68k](https://drive.weixin.qq.com/s?k=AJEAIQdfAAozQt5B8k) + +1. 预训练数据集: +- 名称: Panglao scRNA-seq数据集 +- 格式: H5AD文件 +- 路径: ./data/panglao_10000.h5ad +- 特征数: 16906个基因 +- 数据大小: 99.3MB + +2. 微调数据集: +- 名称: Zheng68k scRNA-seq数据集 +- 格式: H5AD文件 +- 路径: ./data/Zheng68k_prepeocessed.h5ad +- 特征数: 17053个基因 +- 细胞类型数: 11 +- 数据大小: 262MB + +支持的数据集:panglao_10000 Zheng68k 或者与 AnnData 格式相同的数据集 + +- 目录结构如下,由用户定义目录和文件的名称 + +![image](demo/predict-demo.jpg) + +- 如果用户需要自定义数据集,则需要将数据集格式转化为AnnData数据格式。 + +# 特性 + +1. 分布式训练支持 + - 数据并行(Data Parallel) + - 流水线并行(Pipeline Parallel) + +2. 动态学习率 + - 使用ExponentialDecayLR进行学习率调整 + +3. 混合精度训练 + - 支持FP16训练 + +4. 
昇腾硬件适配 + - 支持昇腾910训练推理 + +# 环境要求 + +- 硬件(Ascend) + - 使用Ascend处理器来搭建硬件环境。 +- 框架 + - [MindSpore](https://www.mindspore.cn/install) +- 如需查看详情,请参见如下资源 + - [MindSpore教程](https://www.mindspore.cn/tutorials/zh-CN/master/index.html) + - [MindSpore Python API](https://www.mindspore.cn/docs/zh-CN/master/api_python/mindspore.html) + +# 快速入门 + +- 通过官方网站安装Mindspore后,您可以按照如下步骤进行预训练和微调 + +```shell +# 单卡训练 +python pretrain.py \ + --data_path='./data/panglao_10000.h5ad' \ + --epoch=100 \ + --batch_size=4 \ + --learning_rate=1e-4 +``` + +```shell +# 通过shell脚本进行8卡训练 +bash run_distribute_pretrain.sh +``` + +```shell +# 单卡微调 +python finetune.py \ + --data_path='./data/Zheng68k_prepeocessed.h5ad' \ + --model_path='./ckpt/ckpt-0.ckpt' \ + --epoch=100 \ + --batch_size=1 \ + --learning_rate=1e-4 +``` + +```shell +# 多卡微调 +bash run_distribute_finetune.sh +``` + +# 脚本说明 + +## 脚本及样例代码 + +```text + |----README_CN.md + |----ckpt + |----data + |----demo + | |----AnnData.png + |----run_distribute_pretrain.sh + |----run_distribute_finetune.sh + |----pretrain.py + |----finetune.py + |----performer.py + |----dataset_pretrain.py + |----dataset_finetune.py + |----layers.py + |----utils.py +``` + +## 脚本参数 + +train.py中主要的参数如下: + +```text + +--enable_pipeline 是否启用流水线并行,默认为True +--device_id 设备ID +--bin_num 分箱数量,默认值:5 +--gene_num 基因数量,默认值:16906 +--epoch 训练轮数,默认值:100 +--seed 随机种子,默认值:2021 +--batch_size 批次大小,默认值:4 +--learning_rate 学习率,默认值:1e-4 +--valid_every 验证间隔,默认值:1 +--mask_prob 掩码概率,默认值:0.15 +--replace_prob 替换概率,默认值:0.9 +--pos_embed 是否使用Gene2vec编码,默认为True +--data_path 数据路径 +--model_name 模型名称,默认为panglao_pretrain +``` + +## 训练过程 + +### 单卡训练 + +在Ascend设备上,使用python脚本直接开始训练(单卡) + + python命令启动 + + ```shell + # 单卡训练 + python pretrain.py --device_id 0 + ``` + +### 分布式训练 + +在Ascend设备上,使用shell脚本执行分布式训练示例(8卡) + +```shell +# 通过shell脚本进行8卡训练 +bash run_distribute_pretrain.sh +``` + +```log + + 上述shell脚本将在后台运行分布式训练。您可以通过training_log.txt文件查看结果。得到如下损失值: + + ```log + + ... + == Epoch: 1 | Training Loss: 0.029950 | Accuracy: 17.1237% == + == Epoch: 1 | Validation Loss: 1.671589 | Accuracy: 0.0000% == + == Epoch: 2 | Training Loss: 0.022785 | Accuracy: 32.4212% == + == Epoch: 2 | Validation Loss: 1.253894 | Accuracy: 3.1250% == + == Epoch: 3 | Training Loss: 0.017635 | Accuracy: 61.4334% == + == Epoch: 3 | Validation Loss: 0.898995 | Accuracy: 75.6098% == + ... 
+ +``` + +## 微调过程 + +### 微调 + +#### python命令启动 + +```shell + +python finetune.py \ + --model_path='./ckpt/pretrain-99.ckpt' \ + --data_path='./data/Zheng68k_prepeocessed.h5ad' + +``` + +#### shell脚本启动 + +```shell + +bash run_distribute_finetune.sh + +``` + +```text + + == Epoch: 1 | Training Loss: 2.027127 | Accuracy: 28.5007% == + == Epoch: 1 | Validation Loss: 1.894380 | Accuracy: 0.300657 == + == Epoch: 2 | Training Loss: 1.293512 | Accuracy: 54.2020% == + == Epoch: 2 | Validation Loss: 0.852387 | Accuracy: 0.695179 == + == Epoch: 3 | Training Loss: 0.617621 | Accuracy: 78.1191% == + == Epoch: 3 | Validation Loss: 0.685155 | Accuracy: 0.738422 == + == Epoch: 4 | Training Loss: 0.395844 | Accuracy: 86.8700% == + == Epoch: 4 | Validation Loss: 0.698182 | Accuracy: 0.741563 == + == Epoch: 5 | Training Loss: 0.249119 | Accuracy: 92.2498% == + == Epoch: 5 | Validation Loss: 0.716395 | Accuracy: 0.756903 == + == Epoch: 6 | Training Loss: 0.163563 | Accuracy: 95.0767% == + == Epoch: 6 | Validation Loss: 0.801939 | Accuracy: 0.752739 == + +``` + +## 推理过程 + +**推理前需使用finetune.py文件生成的模型检查点文件。** + +### 用法 + +执行完整的推理脚本如下: + +```shell + +python predict.py + +``` + +### 结果 + +推理结果保存在当前路径,通过prediction_log.log中看到最终预测结果。 + +```text + +2024-10-29 15:20:13,565 - INFO - Predictions: ['CD19+ B', 'CD19+ B', 'CD19+ B', 'CD19+ B', 'CD19+ B', 'CD19+ B', 'CD19+ B', 'CD19+ B', 'CD19+ B', 'CD19+ B'] + + +``` +# 随机情况说明 + +在训练中存在以下随机性来源: + +1. 数据集随机切分 +2. 随机初始化模型参数 +3. Dropout随机失活 +4. Performer中的随机投影矩阵 +5. 训练数据的随机打乱 + +为了确保结果可复现,我们: +- 使用固定随机种子(--seed参数) +- 使用确定性计算模式 + +# ModelZoo主页 + +请浏览官网[主页](https://gitee.com/mindspore/models)。 \ No newline at end of file diff --git a/scBERT/dataset_finetune.py b/scBERT/dataset_finetune.py new file mode 100644 index 0000000000000000000000000000000000000000..618da8afe16e8c8fbfe438e6f32cc1515453f5ed --- /dev/null +++ b/scBERT/dataset_finetune.py @@ -0,0 +1,53 @@ +import mindspore +from mindspore import Tensor +import numpy as np +import scanpy as sc +from sklearn.model_selection import train_test_split +import pickle as pkl + +# 微调用的带有标签的类型 +class SCDataset: + def __init__(self, data, labels, n_class): + self.data = data + self.labels = labels + self.n_class = n_class + + def __getitem__(self, index): + full_seq = self.data[index].toarray()[0] # 假设输入data是稀疏矩阵格式 + full_seq[full_seq > (self.n_class - 2)] = self.n_class - 2 + full_seq = np.append(full_seq, 0).astype(np.int32) # 添加额外的类别 + label = self.labels[index] + label = np.array(label, dtype=np.int32) + return Tensor(full_seq), Tensor(label) + + def __len__(self): + return self.data.shape[0] + + # MindSpore特定: 转换为MindSpore数据集 + def to_mind_dataset(self, batch_size=32, repeat_size=1): + def generator(): + for i in range(len(self)): + # yield self[i], + data, label = self[i] # 假设 self[i] 返回一个 (data, label) 元组 + yield (data, label) + + # 创建数据集 + types = [mindspore.int32, mindspore.int32] + c_names = ["data", "label"] + ds = mindspore.dataset.GeneratorDataset(generator, column_names=c_names, column_types=types) + ds = ds.batch(batch_size).repeat(repeat_size) + return ds + +def load_data(data_path, n_class, seed, batch_size): + data = sc.read_h5ad(data_path) + label_dict, label = np.unique(np.array(data.obs['celltype']), return_inverse=True) + with open('label_dict', 'wb') as fp: + pkl.dump(label_dict, fp) + with open('label', 'wb') as fp: + pkl.dump(label, fp) + data = data.X + X_train, X_test, y_train, y_test = train_test_split(data, label, test_size=0.1, random_state=42) + train_dataset = SCDataset(X_train, y_train, 
n_class).to_mind_dataset(batch_size=batch_size) + val_dataset = SCDataset( X_test, y_test, n_class).to_mind_dataset(batch_size=batch_size) + print("load data success, train num is {}, val num is {}".format(len(train_dataset), len(val_dataset))) + return train_dataset, val_dataset \ No newline at end of file diff --git a/scBERT/dataset_pretrain.py b/scBERT/dataset_pretrain.py new file mode 100644 index 0000000000000000000000000000000000000000..e87aa612c0caf1f28c1b1d2cf1e6e11c1fe7286e --- /dev/null +++ b/scBERT/dataset_pretrain.py @@ -0,0 +1,47 @@ +import mindspore +from mindspore import Tensor +import numpy as np +import scanpy as sc +from sklearn.model_selection import train_test_split +from mindspore.communication import get_group_size, get_rank + +class SCDataset: + def __init__(self, data, n_class, seq_len): + self.data = data + self.n_class = n_class + self.seq_len = seq_len + + def __getitem__(self, index): + full_seq = self.data[index].toarray()[0] # 假设输入data是稀疏矩阵格式 + full_seq[full_seq > (self.n_class - 2)] = self.n_class - 2 + full_seq = np.append(full_seq, 0).astype(np.int32) # 添加额外的类别 + return Tensor(full_seq[:self.seq_len]) + + def __len__(self): + return self.data.shape[0] + + # MindSpore特定: 转换为MindSpore数据集 + def to_mind_dataset(self, batch_size=32, repeat_size=1, DP=False): + def generator(): + for i in range(len(self)): + yield self[i], + + # 创建数据集 + types = [mindspore.int32,] + if DP: + group_size = get_group_size() + rank_id = get_rank() + ds = mindspore.dataset.GeneratorDataset(self, column_names=["data"], column_types=types, num_shards=group_size, shard_id=rank_id) + else: + ds = mindspore.dataset.GeneratorDataset(self, column_names=["data"], column_types=types) + ds = ds.batch(batch_size).repeat(repeat_size) + return ds + +def load_data(data_path, n_class, seed, batch_size, seq_len, args): + data = sc.read_h5ad(data_path) + data = data.X + data_train, data_val = train_test_split(data, test_size=0.1, random_state=seed) + train_dataset = SCDataset(data_train, n_class, seq_len).to_mind_dataset(batch_size=batch_size, DP=args.enable_dp) + val_dataset = SCDataset(data_val, n_class, seq_len).to_mind_dataset(batch_size=batch_size, DP=args.enable_dp) + print("load data success, train num is {}, val num is {}".format(len(train_dataset), len(val_dataset))) + return train_dataset, val_dataset \ No newline at end of file diff --git a/scBERT/demo/AnnData.png b/scBERT/demo/AnnData.png new file mode 100644 index 0000000000000000000000000000000000000000..e6aa27577a0d28323dcadf6deffab18faf0dbf1c Binary files /dev/null and b/scBERT/demo/AnnData.png differ diff --git a/scBERT/finetune.py b/scBERT/finetune.py new file mode 100644 index 0000000000000000000000000000000000000000..16820d15684212b1a95854fe44f61cd8d02af387 --- /dev/null +++ b/scBERT/finetune.py @@ -0,0 +1,245 @@ +import argparse +import numpy as np +from dataset_finetune import load_data +from performer import PerformerLM +from mindspore.nn import Adam, CrossEntropyLoss +from tqdm import tqdm +from mindspore import ops, save_checkpoint, Tensor +import math +from functools import reduce +import mindspore as ms +from mindspore import value_and_grad, ParallelMode, nn +from mindspore.communication import init +from mindspore import Profiler +import pickle as pkl +from sklearn.metrics import accuracy_score +import os + +# 微调中新的输出层 +class Identity(nn.Cell): + def __init__(self, dropout = 0.1, h_dim = 100, out_dim = 10): + super(Identity, self).__init__() + self.conv1 = nn.Conv2d(1, 1, (1,200), pad_mode='valid', padding=0, has_bias=False) 
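+        # the (1, 200) kernel collapses each gene position's 200-dim embedding to a scalar: [batch, 1, seq_len, 200] -> [batch, 1, seq_len, 1]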
+ self.act = nn.ReLU() + self.fc1 = nn.Dense(in_channels=SEQ_LEN, out_channels=512, has_bias=True) + self.act1 = nn.ReLU() + self.dropout1 = nn.Dropout(dropout) + self.fc2 = nn.Dense(in_channels=512, out_channels=h_dim, has_bias=True) + self.act2 = nn.ReLU() + self.dropout2 = nn.Dropout(dropout) + self.fc3 = nn.Dense(in_channels=h_dim, out_channels=out_dim, has_bias=True) + + def construct(self, x): + x = x[:,None,:,:] + x = self.conv1(x) + x = self.act(x) + x = x.view(x.shape[0],-1) + x = self.fc1(x) + x = self.act1(x) + x = self.dropout1(x) + x = self.fc2(x) + x = self.act2(x) + x = self.dropout2(x) + x = self.fc3(x) + return x + +model = None +loss_fn = None + +def cum_loss_and_logits(data, label): + global model, loss_fn, SEQ_LEN + logits = model(data) + loss = loss_fn(logits, label) + return loss, logits + +def build_model(args): + global CLASS, SEQ_LEN, POS_EMBED_USING, model + #load the label stored + with open('label_dict', 'rb') as fp: + label_dict = pkl.load(fp) + model = PerformerLM( + num_tokens = CLASS, + dim = 200, + depth = 6, + max_seq_len = SEQ_LEN, + heads = 10, + ) + args = parse() + # 加载预训练权重 + ckpt_file_name = args.model_path + param_dict = ms.load_checkpoint(ckpt_file_name) + # 将权重加载到模型中 + ms.load_param_into_net(model, param_dict) + # 设置参数是否参与梯度计算 + for param in model.trainable_params(): + param.requires_grad = False + for param in model.norm.trainable_params(): + param.requires_grad = True + for param in model.performer.layers[-2].trainable_params(): + param.requires_grad = True + # 覆盖输出层 + model.to_out = Identity(dropout=0.1, h_dim=128, out_dim=label_dict.shape[0]) + print("build model success.") + count = sum([ item.size for item in model.get_parameters()]) + names = [item.name for item in model.trainable_params()] + print("param count is {}, names: {}, count: {}".format(count, str(names), len(names))) + + if args.enable_pipeline: + model.init_pipeline(0) + model.performer.layers[0].init_pipeline(1) + model.performer.layers[0].attention.init_pipeline(1) + return + +def build_optimizer_and_scheduler(model): + global LEARNING_RATE, PAD_TOKEN_ID, loss_fn, optimizer + # optimizer + optimizer = Adam(params=model.trainable_params(), learning_rate=LEARNING_RATE) + # loss + loss_fn = CrossEntropyLoss(weight=None) + print("build optimizer success.") + return optimizer + +def train_one_epoch(train_dataloader, grad_fn, optimizer): + global model + running_loss = 0.0 + cum_acc = 0.0 + model.set_train(True) + for _, (data, label) in enumerate(tqdm(train_dataloader.create_tuple_iterator())): + # forward 推理 + (loss, logits), grads = grad_fn(data, label) + optimizer(grads) + # 累加损失 + running_loss += loss.item() + # 计算精度 + final = ops.softmax(logits) + final = final.argmax(axis=-1) + # 预测数 + pred_num = Tensor([final.shape[-1]], ms.int32) + # 计算正确数 + correct_num = ops.Equal()(final, label).sum(axis=-1) + # 计算累计准确率 + cum_acc += correct_num / pred_num.mean() + del data, label, final + + return running_loss, cum_acc + +# 从 Tensor 对象中提取整数值 +def get_value_from_tensor(tensor_list): + return [tensor.asnumpy()[0] for tensor in tensor_list] + +def eval_one_epoch(val_dataloader): + global loss_fn, model, SEQ_LEN + model.set_train(False) + predictions = [] + truths = [] + running_loss = 0.0 + print("========== 开始验证") + for _, (data,label) in enumerate(tqdm(val_dataloader.create_tuple_iterator())): + logits = model(data) + loss = loss_fn(logits, label) + running_loss += loss.item() + softmax = nn.Softmax(axis=-1) + final_prob = softmax(logits) + final = final_prob.argmax(axis=-1) + 
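# collect per-batch predictions and labels; get_value_from_tensor() takes element [0] of each tensor, so the accuracy computed below assumes batch_size == 1 +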
predictions.append(final) + truths.append(label) + del data, logits, final + val_loss = running_loss / len(val_dataloader) + # 获取 truths 和 predictions 的实际值 + truths_values = get_value_from_tensor(truths) + predictions_values = get_value_from_tensor(predictions) + # 计算正确率 + correct_count = sum(t == p for t, p in zip(truths_values, predictions_values)) + total_count = len(truths_values) + val_acc = correct_count / total_count if total_count > 0 else 0 + # 计算正确数 + del predictions, truths + return val_loss, val_acc + +def train(optimizer, train_dataloader, val_dataloader): + global EPOCHS,VALIDATE_EVERY, MODEL_NAME, loss_fn + + train_num_step = len(train_dataloader) + grad_fn = value_and_grad(cum_loss_and_logits, grad_position=None, weights=model.trainable_params(), has_aux=True) + for epoch in range(EPOCHS): + running_loss, cum_acc = train_one_epoch(train_dataloader, grad_fn, optimizer) + # log epoch的信息 + epoch_loss = running_loss / train_num_step + epoch_acc = 100 * cum_acc / train_num_step + + # 确保将Tensor转换为Python数值 + epoch_loss_value = epoch_loss.asnumpy().item() if isinstance(epoch_loss, ms.Tensor) else epoch_loss + epoch_acc_value = epoch_acc.asnumpy().item() if isinstance(epoch_acc, ms.Tensor) else epoch_acc + + log_string = f' == Epoch: {epoch} | Training Loss: {epoch_loss_value:.6f} | Accuracy: {epoch_acc_value:6.4f}% ==' + print(log_string) + with open('finetune_result.txt', 'a') as f: + f.write(log_string + '\n') + + # 进行一次验证 + if epoch % VALIDATE_EVERY == 0: + val_loss, val_acc = eval_one_epoch(val_dataloader) + log_string = f' == Epoch: {epoch} | Validation Loss: {val_loss} | Accuracy: {val_acc.item()}% ==' + print(log_string) + with open('finetune_result.txt', 'a') as f: + f.write(log_string + '\n') + + ckpt_dir = "./" + FINETUNE_SAVE_PATH + if not os.path.exists(ckpt_dir): + os.makedirs(ckpt_dir, exist_ok=True) + ckpt_file = f"finetune-{epoch}.ckpt" + ckpt_path = os.path.join(ckpt_dir, ckpt_file) + save_checkpoint(model, ckpt_path) + +def parse(): + parser = argparse.ArgumentParser() + parser.add_argument("--enable_pipeline", type=bool, default=False, help='Local process rank.') + parser.add_argument("--device_id", type=int, default=-1, help='Local process rank.') + parser.add_argument("--bin_num", type=int, default=5, help='Number of bins.') + parser.add_argument("--gene_num", type=int, default=16906, help='Number of genes.') + parser.add_argument("--epoch", type=int, default=100, help='Number of epochs.') + parser.add_argument("--seed", type=int, default=2021, help='Random seed.') + parser.add_argument("--batch_size", type=int, default=1, help='Number of batch size.') + parser.add_argument("--learning_rate", type=float, default=1e-4, help='Learning rate.') + parser.add_argument("--grad_acc", type=int, default=60, help='Number of gradient accumulation.') + parser.add_argument("--valid_every", type=int, default=1, help='Number of training epochs between twice validation.') + parser.add_argument("--pos_embed", type=bool, default=True, help='Using Gene2vec encoding or not.') + parser.add_argument("--data_path", type=str, default='./data/Zheng68k_prepeocessed.h5ad', help='Path of data for finetune.') + parser.add_argument("--model_path", type=str, default='./ckpt/ckpt-0.ckpt', help='Path of pretrained model.') + parser.add_argument("--ckpt_dir", type=str, default='./finetune_ckpts/', help='Directory of checkpoint to save.') + parser.add_argument("--model_name", type=str, default='finetune', help='Finetuned model name.') + args = parser.parse_args() + return args + +if __name__ == 
"__main__": + # 1. 解析命令行参数 + args = parse() + if args.enable_pipeline: + ms.set_context(mode=0, device_target="Ascend") + ms.set_auto_parallel_context(parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL, pipeline_stages=2, pipeline_result_broadcast=True) + init() + ms.set_seed(1) + else: + ms.set_context(max_device_memory='29GB') + ms.set_context(mode=0, device_target="Ascend", device_id=0) + # 2. 声明全局变量 + SEED = args.seed + EPOCHS = args.epoch + BATCH_SIZE = args.batch_size + GRADIENT_ACCUMULATION = args.grad_acc + LEARNING_RATE = args.learning_rate + SEQ_LEN = args.gene_num + 1 + VALIDATE_EVERY = args.valid_every + PATIENCE = 10 + UNASSIGN_THRES = 0.0 + CLASS = args.bin_num + 2 + POS_EMBED_USING = args.pos_embed + FINETUNE_SAVE_PATH = args.ckpt_dir + # 3. 加载数据集 + train_dataloader, val_dataloader = load_data(args.data_path, CLASS, SEED, BATCH_SIZE) + # 4. 加载模型 + build_model(args) + # 4. 构建优化器和损失函数 + optimizer = build_optimizer_and_scheduler(model) + # 5. 开始训练 + train(optimizer, train_dataloader, val_dataloader) diff --git a/scBERT/layers.py b/scBERT/layers.py new file mode 100644 index 0000000000000000000000000000000000000000..5c8f78f1941c49e5ea4448a7f437f9f291d92dcc --- /dev/null +++ b/scBERT/layers.py @@ -0,0 +1,110 @@ +from mindspore.nn import Cell, Embedding, GELU, Dense, Dropout,ReLU, LayerNorm, Softmax +from mindspore import Tensor, ops +import numpy as np +import mindspore as ms +from utils import default +import numpy as np +from mindspore import Tensor + +class Gene2VecPositionalEmbedding(Cell): + def __init__(self, max_seq_len=16907): + super().__init__() + gene2vec_weight = np.load('./data/gene2vec_16906.npy') + gene2vec_weight = gene2vec_weight + gene2vec_weight = np.concatenate((gene2vec_weight, np.zeros((1, gene2vec_weight.shape[1]))), axis=0) + gene2vec_weight = Tensor(gene2vec_weight, dtype=ms.float32) + self.emb = Embedding(vocab_size=max_seq_len, embedding_size=200, embedding_table=gene2vec_weight, dtype=ms.float32) + self.emb.embedding_table.requires_grad=False + + def construct(self, x): + t = ops.arange(start=0, end=x.shape[1], dtype=ms.int32) + return self.emb(t) + +class BatchMatricMul(Cell): + def __init__(self, transpose_a=False, transpose_b=False): + super().__init__() + self.matmul = ops.BatchMatMul(transpose_a, transpose_b) + + def construct(self, a, b): + return self.matmul(a, b) + +class Add(Cell): + def __init__(self): + super().__init__() + self.add = ops.Add() + + def construct(self, a, b): + return self.add(a, b) + +class SelfAttention(Cell): + def __init__( + self, + dim, + heads = 8, + dim_head = 64, + dropout = 0., + ): + super().__init__() + assert dim % heads == 0, 'dimension must be divisible by number of heads' + self.dim_head = default(dim_head, dim // heads) + self.inner_dim = dim_head * heads + + self.heads = heads + self.reshape = ops.Reshape() + # stage 1 + self.to_q = Dense(in_channels=dim, out_channels=self.inner_dim, dtype=ms.float32, has_bias=False) + self.to_k = Dense(in_channels=dim, out_channels=self.inner_dim, dtype=ms.float32, has_bias=False) + self.to_v = Dense(in_channels=dim, out_channels=self.inner_dim, dtype=ms.float32, has_bias=False) + self.to_out = Dense(in_channels=self.inner_dim, out_channels=dim, dtype=ms.float32, has_bias=False) + self.dropout1 = Dropout(p=dropout+0.00000001) + + # stage 2 + self.matmul = BatchMatricMul(False, True) + self.softmax = Softmax(axis=-1) + self.mul = BatchMatricMul() + self.layer_norm = LayerNorm((dim,)) + self.w1 = Dense(in_channels=dim, out_channels=dim * 4, dtype=ms.float32) + self.act = GELU() + 
self.dropout2 = Dropout(p=dropout+0.000001) + self.w2 = Dense(in_channels=dim * 4, out_channels=dim, dtype=ms.float32) + self.add1 = Add() + self.add2= Add() + + def construct(self, x): + # (batch_size, 16906, 200, 10, 10) + b, n, _, h = *x.shape, self.heads + + # 这里就是 [bs, seq, hidden] -> [bs, head, seq, head_dim] + q, k, v = self.to_q(x), self.to_k(x), self.to_v(x) + q = self.reshape(q, (b, h, n, self.dim_head)) + k = self.reshape(k, (b, h, n, self.dim_head)) + v = self.reshape(v, (b, h, n, self.dim_head)) + score = self.matmul(q, k) + out = self.mul(self.softmax(score), v) + out = self.reshape(out, (b, n, self.inner_dim)) + attn_out = self.to_out(out) + attn_out = self.add1(x, attn_out) + x = self.layer_norm(attn_out) + x = self.w1(x) + x = self.act(x) + x = self.dropout1(x) + x = self.w2(x) + out = self.add2(attn_out,x) + return self.dropout2(attn_out) + + def init_pipeline(self): + self.to_q.pipeline_stage = 0 + self.to_k.pipeline_stage = 0 + self.to_v.pipeline_stage = 0 + self.dropout1.pipeline_stage = 1 + self.matmul.pipeline_stage = 1 + self.mul.pipeline_stage = 1 + self.softmax.pipeline_stage=1 + self.to_out.pipeline_stage = 2 + self.add1.pipeline_stage = 2 + self.w1.pipeline_stage=2 + self.act.pipeline_stage=2 + self.dropout1.pipeline_stage=2 + self.w2.pipeline_stage=3 + self.add2.pipeline_stage=3 + self.dropout2.pipeline_stage=3 \ No newline at end of file diff --git a/scBERT/performer.py b/scBERT/performer.py new file mode 100644 index 0000000000000000000000000000000000000000..955e36b6271fe360d69aec3da692a2043d6ef5a8 --- /dev/null +++ b/scBERT/performer.py @@ -0,0 +1,425 @@ +import math +import numpy as np +import mindspore as ms +import mindspore.common.dtype as mstype +import mindspore.nn as nn +import mindspore.ops.functional as F +from mindspore.ops import operations as P +from mindspore.common.tensor import Tensor +from mindspore.common.parameter import Parameter +from mindspore.common.initializer import initializer, Normal +from layers import Gene2VecPositionalEmbedding + +# helpers +def exists(val): + return val is not None +def empty(tensor): + return tensor.numel() == 0 +def default(val, d): + return val if exists(val) else d + +def softmax_kernel(data, projection_matrix, is_query=False, normalize_data=True, eps=1e-4): + """ + data:[Batch,Heads,Seq,Dim_head] + projection_matrix:[m,Dim_head] + + """ + b, h, Seq,Dim_head= data.shape + data_normalizer = (data.shape[-1] ** -0.25) if normalize_data else 1. 
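+    # FAVOR+ positive random features: phi(x) = exp(w^T x_hat - |x_hat|^2 / 2) / sqrt(m), with x_hat = x * d^(-1/4); the max subtracted below is only for numerical stability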
+ ratio = (projection_matrix.shape[0] ** -0.5) + # W'*X + data_dash = data_normalizer * P.MatMul(transpose_b=True)(P.Reshape()(data,(-1,Dim_head)), projection_matrix) + data_dash = P.Reshape()(data_dash,(b,h,Seq,-1)) + # |X|^2/2 + diag_data = data ** 2 + diag_data = P.ReduceSum(keep_dims=True)(diag_data, -1) + diag_data = (diag_data / 2.0) * (data_normalizer ** 2) + #exp(W'x-|X|^2/2) + if is_query: + data_dash = ratio * ( + P.Exp()(data_dash - diag_data - + P.ReduceMax(keep_dims=True)(data_dash, -1)) + eps) + else: + data_dash = ratio * ( + P.Exp()(data_dash - diag_data - P.ReduceMax()(data_dash)) + eps) + + return data_dash + +def orthogonal_matrix_chunk(cols, qr_uniform_q = False): + unstructured_block = np.random.randn(cols, cols).astype(np.float32) + q, r = np.linalg.qr(unstructured_block, mode='reduced') + # proposed by @Parskatt + # to make sure Q is uniform https://arxiv.org/pdf/math-ph/0609050.pdf + if qr_uniform_q: + d = np.diag(r, 0) + q *= np.sign(d) + # 转mindspore Tensor + q = np.transpose(q) + q = Tensor(q) + return q + +def gaussian_orthogonal_random_matrix(nb_rows, nb_columns, scaling = 0, qr_uniform_q = False): + nb_full_blocks = int(nb_rows / nb_columns) + block_list = [] + for _ in range(nb_full_blocks): + q = orthogonal_matrix_chunk(nb_columns, qr_uniform_q = qr_uniform_q, ) + block_list.append(q) + remaining_rows = nb_rows - nb_full_blocks * nb_columns + if remaining_rows > 0: + q = orthogonal_matrix_chunk(nb_columns, qr_uniform_q = qr_uniform_q, ) + block_list.append(q[:remaining_rows]) + final_matrix = P.Concat()(tuple(block_list)) + + if scaling == 0: + multiplier = Tensor(np.diag(np.linalg.norm(np.random.randn(nb_rows, nb_columns).astype(np.float32), axis = 1))) + elif scaling == 1: + multiplier = Tensor(np.diag(math.sqrt((float(nb_columns))) * np.ones((nb_rows,)))) + else: + raise ValueError(f'Invalid scaling {scaling}') + + return P.MatMul()(multiplier,final_matrix) + +class Softmax_kernel(nn.Cell): + def __init__(self): + super().__init__() + self.Reshape = P.Reshape() + self.MatMul_b = P.MatMul(transpose_b=True) + self.ReduceSum = P.ReduceSum(keep_dims=True) + self.Exp = P.Exp() + self.ReduceMax_keep = P.ReduceMax(keep_dims=True) + self.ReduceMax = P.ReduceMax() + def construct(self, data, projection_matrix, is_query=False, normalize_data=True, eps=1e-4): + """ + data:[Batch,Heads,Seq,Dim_head] + projection_matrix:[m,Dim_head] + + """ + b, h, Seq, Dim_head = data.shape + data_normalizer = (data.shape[-1] ** -0.25) if normalize_data else 1. + ratio = (projection_matrix.shape[0] ** -0.5) + # W'*X + data_dash = data_normalizer * self.MatMul_b(self.Reshape(data, (-1, Dim_head)), projection_matrix) + data_dash = self.Reshape(data_dash, (b, h, Seq, -1)) + # |X|^2/2 + diag_data = data ** 2 + diag_data = self.ReduceMax_keep(diag_data, -1) + diag_data = (diag_data / 2.0) * (data_normalizer ** 2) + # exp(W'x-|X|^2/2) + if is_query: + data_dash = ratio * ( + self.Exp(data_dash - diag_data - + self.ReduceMax_keep(data_dash, -1)) + eps) + else: + data_dash = ratio * ( + self.Exp(data_dash - diag_data - self.ReduceMax(data_dash)) + eps) + + return data_dash + +class Linear_attention(nn.Cell): + def __init__(self): + super().__init__() + self.ReduceSum =P.ReduceSum(keep_dims=True) + self.BatchMatMul_b = P.BatchMatMul(transpose_b=True) + self.BatchMatMul_a = P.BatchMatMul(transpose_a=True) + self.BatchMatMul = P.BatchMatMul() + self.Mul = P.Mul() + def construct(self, q, k, v): + """ + k,q,v:[B,Sq,H] + """ + # [B,1,H] + k_cumsum = self.ReduceSum(k, -2) + # [B,Sq,1] + D_inv = 1. 
/self.BatchMatMul_b(q, k_cumsum) + # [B,H,H] + context = self.BatchMatMul_a(k, v) + # [B,Sq,H] + out = self.BatchMatMul(q, context) + # [B,Sq,H]*[B,Sq,1] -> + out = self.Mul(out, D_inv) + return out + +class Causal_linear_attention(nn.Cell): + def __init__(self): + super().__init__() + self.view_ = P.Reshape() + self.CumSum = P.CumSum() + self.ReduceSum =P.ReduceSum(keep_dims=True) + self.BatchMatMul_b = P.BatchMatMul(transpose_b=True) + self.BatchMatMul_a = P.BatchMatMul(transpose_a=True) + self.Mul = P.Mul() + def construct(self, q, k, v): + k_cumsum = self.CumSum(k, -2) + # [n,] + D_inv = 1. / self.ReduceSum(q * k_cumsum, -1) + # [n,d,1]*[n,1,e] -> [n,d,e] + context = self.BatchMatMul_b(self.view_(k, k.shape + (1,)), self.view_(v, v.shape + (1,))) + #[n,d,e] -> + context = self.CumSum(context,-3) + # [n,1,d] * [n,d,e] -> [n,1,e] = [n,e] + out = self.BatchMatMul_a(self.view_(q, q.shape + (1,)), context) + out = self.view_(out, v.shape) + out = self.Mul(out, D_inv) + return out + +class LayerNorm(nn.Cell): + """ + Layer Normalization + + Args: + normalized_shape: the corresponding shape of the normalized axes + eps: epsilon, a small number avoiding zero division + + Inputs: + x: input tensor + + Returns: + rescaled_output: Tensor, returned tensor after layernorm + """ + def __init__(self, normalized_shape, eps=1e-5): + super(LayerNorm, self).__init__() + self.gamma = Parameter(initializer('ones', normalized_shape), name="gamma") + self.beta = Parameter(initializer('zeros', normalized_shape), name="beta") + self.mean = P.ReduceMean(keep_dims=True) + self.eps = eps + + def construct(self, x): + mean = self.mean(x, -1) + variance = self.mean(F.square(x - mean), -1) + output = (x - mean) / F.sqrt(variance + self.eps) + rescaled_output = output * self.gamma + self.beta + return rescaled_output + +class FeedForward(nn.Cell): + def __init__(self, dim, + mult = 4, + initializer_range=0.02, + hidden_dropout_prob=0.1, + compute_type=mstype.float32): + super(FeedForward,self).__init__() + self.hidden_size = dim + self.w1 = Mapping(dim,dim*mult,initializer_range,compute_type) + self.w2 = Mapping(dim * mult,dim,initializer_range,compute_type) + self.act = nn.GELU() + self.dropout = nn.Dropout(hidden_dropout_prob) + def construct(self, x): + x = self.w1(x) + x = self.act(x) + x = self.w2(x) + x = self.dropout(x) + return x + +class Mapping(nn.Cell): + """ + A mapping function with a 3d input + Args: + input_size: the size of the last dimension of the input tensor + output_size: the desired size of the last dimension of the output tensor + dtype: the compute datatype + scale: the scale factor for initialization + Inputs: + x: the 3d input + Returns: + output: Tensor, a 3d tensor after projection + """ + def __init__(self, input_size, output_size,initializer_range=0.02, dtype=ms.float32, scale=1.0): + super(Mapping, self).__init__() + self.output_size = output_size + self.input_size = input_size + self.weight = Parameter(initializer(Normal(sigma=initializer_range*scale), [input_size, output_size]),name="Weight") + self.bias = Parameter(initializer("zeros", [output_size,]),name="Bias") + self.dtype = dtype + self.cast = P.Cast() + + def construct(self, x): + out_shape = P.Shape()(x)[:-1] + (self.output_size,) + x = P.Reshape()(x, (-1, self.input_size)) + x = nn.MatMul()(x, self.cast(self.weight, self.dtype)) + self.cast(self.bias, self.dtype) + output = P.Reshape()(x, out_shape) + return output + +class FastAttention(nn.Cell): + def __init__(self, dim_heads, nb_features = None, ortho_scaling = 0, causal = 
False, qr_uniform_q = False): + super(FastAttention, self).__init__() + nb_features = default(nb_features, int(dim_heads * math.log(dim_heads))) + self.dim_heads = dim_heads + self.nb_features = nb_features + self.ortho_scaling = ortho_scaling + ## projection_matrix is buffer + self.projection_matrix = gaussian_orthogonal_random_matrix(nb_rows=self.nb_features, + nb_columns=dim_heads, + scaling=ortho_scaling, + qr_uniform_q=qr_uniform_q) + self.causal = causal + self.attn_fn = Linear_attention() if not self.causal else Causal_linear_attention() + self.softmax_kernel = Softmax_kernel() + def construct(self, q, k, v): + q = self.softmax_kernel(data=q, projection_matrix=self.projection_matrix, is_query=True) + k = self.softmax_kernel(data=k, projection_matrix=self.projection_matrix, is_query=False) + out = self.attn_fn(q, k, v) + return out + +class SelfAttention(nn.Cell): + def __init__(self, dim, heads, dim_head, causal=False, nb_features=None, qr_uniform_q = False, dropout = 0.9): + super(SelfAttention,self).__init__() + assert dim % heads == 0, 'dimension must be divisible by number of heads' + self.dim_head = dim_head + self.fast_attention = FastAttention(dim_heads=self.dim_head, nb_features=nb_features, causal=causal, qr_uniform_q=qr_uniform_q) + self.heads = heads + self.to_q = Mapping(dim, dim) + self.to_k = Mapping(dim, dim) + self.to_v = Mapping(dim, dim) + self.to_out = Mapping(dim, dim) + self.dropout = nn.Dropout(dropout) + self.view = P.Reshape() + self.Concat = P.Concat(axis=1) + self.Mul = P.Mul() + self.ExpandDims = P.ExpandDims() + self.Tile = P.Tile() + def construct(self, x): + """ + #b:batch_size + #h:num_heads + #n:seq_len + #d:dim_perhead + """ + b, n, dim, = x.shape + h = self.heads + + q, k, v = self.to_q(x), self.to_k(x), self.to_v(x) + q, k, v = self.view(q, (b,h,n,self.dim_head)), self.view(k, (b,h,n,self.dim_head)), self.view(v, (b,h,n,self.dim_head)) + + out = self.fast_attention(q, k, v) + out = self.view(out, (b,n,h* self.dim_head)) + out = self.to_out(out) + + return self.dropout(out) + +class EmbeddingLookup(nn.Cell): + """ + A embeddings lookup table with a fixed dictionary and size. + + Args: + vocab_size (int): Size of the dictionary of embeddings. + embedding_size (int): The size of each embedding vector. + use_one_hot_embeddings (bool): Specifies whether to use one hot encoding form. Default: False. + initializer_range (float): Initialization value of TruncatedNormal. Default: 0.02. 
+ """ + def __init__(self, + vocab_size, + embedding_size, + use_one_hot_embeddings=False, + initializer_range=0.02): + super(EmbeddingLookup, self).__init__() + self.vocab_size = vocab_size + self.embedding_size = embedding_size + self.use_one_hot_embeddings = use_one_hot_embeddings + self.embedding_table = Parameter(initializer(Normal(sigma=initializer_range), + [vocab_size, embedding_size]), name="embedding_table") + self.expand = P.ExpandDims() + self.shape_flat = (-1,) + self.gather = P.GatherV2() + self.one_hot = P.OneHot() + self.on_value = Tensor(1.0, mstype.float32) + self.off_value = Tensor(0.0, mstype.float32) + self.array_mul = P.MatMul() + self.reshape = P.Reshape() + self.shape = P.Shape() + + def construct(self, input_ids): + """Get a embeddings lookup table with a fixed dictionary and size.""" + input_shape = self.shape(input_ids) + + flat_ids = self.reshape(input_ids, self.shape_flat) + if self.use_one_hot_embeddings: + one_hot_ids = self.one_hot(flat_ids, self.vocab_size, self.on_value, self.off_value) + output_for_reshape = self.array_mul(one_hot_ids, self.embedding_table) + else: + output_for_reshape = self.gather(self.embedding_table, flat_ids, 0) + + out_shape = input_shape + (self.embedding_size,) + output = self.reshape(output_for_reshape, out_shape) + return output + +class AbsolutePositionalEmbedding(nn.Cell): + def __init__(self, dim, max_seq_len): + super(AbsolutePositionalEmbedding, self).__init__() + self.emb = nn.EmbeddingLookup(max_seq_len, dim) + + def construct(self, x): + batch_size, seq_length = x.shape[0], x.shape[1] + input_position = F.tuple_to_array(F.make_range(seq_length)) + # input_position = P.Tile()(input_position, (batch_size, 1)) + return self.emb(input_position) + + +class Performer_layer(nn.Cell): + def __init__(self,dim, heads, dim_head, causal=False, nb_features=None, qr_uniform_q = False, dropout = 0.9): + super(Performer_layer, self).__init__() + self.SelfAttention = SelfAttention(dim, heads, dim_head, causal, nb_features, qr_uniform_q, dropout) + self.FeedForward = FeedForward(dim=dim) + self.LayerNorm = LayerNorm(dim,) + def construct(self, x): + + x = self.LayerNorm(x) + out = x + self.SelfAttention(x) + out = self.LayerNorm(out) + out = out + self.FeedForward(x) + return out + +class Performer(nn.Cell): + def __init__(self,dim, depth, heads, causal=False, nb_features=None, qr_uniform_q = False, dropout = 0.9): + super(Performer, self).__init__() + assert dim % heads == 0 + dim_head = dim//heads + layers = [] + for _ in range(depth): + layers.append(Performer_layer(dim=dim, heads=heads, + dim_head=dim_head, + causal=causal, + nb_features=nb_features, + qr_uniform_q=qr_uniform_q, + dropout=dropout )) + + self.layers = nn.CellList(layers) + + def construct(self, input_tensor): + prev_output = input_tensor + for layer_module in self.layers: + prev_output = layer_module(prev_output) + return prev_output + +class PerformerLM(nn.Cell): + def __init__(self, num_tokens, max_seq_len, dim, depth, heads, causal = True, + nb_features = None, emb_dropout = 0.9, pf_dropout = 0.9, qr_uniform_q = False): + super(PerformerLM,self).__init__() + self.max_seq_len = max_seq_len + self.dim = dim + self.num_tokens = num_tokens + self.token_emb = EmbeddingLookup(num_tokens, dim) + # self.pos_emb = AbsolutePositionalEmbedding(dim, max_seq_len) + self.pos_emb = Gene2VecPositionalEmbedding() + self.dropout = nn.Dropout(emb_dropout) + self.performer = Performer(dim, depth, heads, causal, nb_features, qr_uniform_q, pf_dropout ) + self.norm = LayerNorm(dim) + 
self.MatMul = P.MatMul(transpose_b=True) + self.Reshape = P.Reshape() + self.to_out = nn.Dense(dim, num_tokens, dtype=ms.float32) + def construct(self, input_ids): + # b, n = input_ids.shape + # assert n <= self.max_seq_len, f'sequence length {n} must be less than the max sequence length {self.max_seq_len}' + # token and positional embeddings + + x = self.token_emb(input_ids) + x += self.pos_emb(x) + x = self.dropout(x) + x = self.performer(x) + # norm and to logits + #[batch,seq,hidden] + x = self.norm(x) + # res = self.MatMul(self.Reshape(x,(-1,self.dim)), self.token_emb.embedding_table) + # return self.Reshape(res, input_ids.shape+(self.num_tokens,)) + # 5. (batch, 16906, 200) -> (batch, 16906, 7) + # 输出层 + x = self.to_out(x) + return x \ No newline at end of file diff --git a/scBERT/predict.py b/scBERT/predict.py new file mode 100644 index 0000000000000000000000000000000000000000..a3c68d689ac5bcbd2636c4c8a1dc44d47faf6e89 --- /dev/null +++ b/scBERT/predict.py @@ -0,0 +1,129 @@ +import argparse +import numpy as np +from dataset_finetune import load_data +from performer import PerformerLM +from mindspore.nn import Adam, CrossEntropyLoss +from tqdm import tqdm +from mindspore import ops, save_checkpoint, Tensor +import math +from functools import reduce +import mindspore as ms +from mindspore import value_and_grad, ParallelMode, nn +from mindspore.communication import init +from mindspore import Profiler +import pickle as pkl +from sklearn.metrics import accuracy_score +import logging +import scanpy as sc + +class Identity(nn.Cell): + def __init__(self, dropout = 0.1, h_dim = 100, out_dim = 10): + super(Identity, self).__init__() + self.conv1 = nn.Conv2d(1, 1, (1,200), pad_mode='valid', padding=0, has_bias=False) + self.act = nn.ReLU() + self.fc1 = nn.Dense(in_channels=SEQ_LEN, out_channels=512, has_bias=True) + self.act1 = nn.ReLU() + self.dropout1 = nn.Dropout(dropout) + self.fc2 = nn.Dense(in_channels=512, out_channels=h_dim, has_bias=True) + self.act2 = nn.ReLU() + self.dropout2 = nn.Dropout(dropout) + self.fc3 = nn.Dense(in_channels=h_dim, out_channels=out_dim, has_bias=True) + + def construct(self, x): + x = x[:,None,:,:] + # [batch, 1, seq_len, 200] + x = self.conv1(x) + # [batch, 1, seq_len, 1] + x = self.act(x) + x = x.view(x.shape[0],-1) + x = self.fc1(x) + x = self.act1(x) + x = self.dropout1(x) + x = self.fc2(x) + x = self.act2(x) + x = self.dropout2(x) + x = self.fc3(x) + return x + +def parse(): + parser = argparse.ArgumentParser() + parser.add_argument("--enable_pipeline", type=bool, default=False, help='Local process rank.') + parser.add_argument("--device_id", type=int, default=-1, help='Local process rank.') + parser.add_argument("--bin_num", type=int, default=5, help='Number of bins.') + parser.add_argument("--gene_num", type=int, default=16906, help='Number of genes.') + parser.add_argument("--epoch", type=int, default=100, help='Number of epochs.') + parser.add_argument("--seed", type=int, default=2021, help='Random seed.') + parser.add_argument("--pos_embed", type=bool, default=True, help='Using Gene2vec encoding or not.') + parser.add_argument("--data_path", type=str, default='./data/Zheng68k_prepeocessed.h5ad', help='Path of data for predict.') + parser.add_argument("--model_path", type=str, default='./ckpt/ckpt-0.ckpt', help='Path of finetuned model.') + args = parser.parse_args() + return args + +if __name__ == "__main__": + # 配置日志记录到文件 'prediction_log.log' + logging.basicConfig( + filename='prediction_log.log', + filemode='a', # 'a' 表示追加模式,如果要覆盖日志文件,可以使用 'w' + 
format='%(asctime)s - %(levelname)s - %(message)s', + level=logging.INFO + ) + + # 解析命令行参数 + args = parse() + if args.enable_pipeline: + ms.set_context(mode=0, device_target="Ascend") + ms.set_auto_parallel_context(parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL, pipeline_stages=2, pipeline_result_broadcast=True) + init() + ms.set_seed(1) + else: + ms.set_context(variable_memory_max_size='29GB') + ms.set_context(mode=0, device_target="Ascend", device_id=0) + + # 声明全局变量 + SEED = args.seed + EPOCHS = args.epoch + SEQ_LEN = args.gene_num + 1 + CLASS = args.bin_num + 2 + POS_EMBED_USING = args.pos_embed + + # 读预测数据集 + data = sc.read_h5ad(args.data_path) + # 标签字典 + with open('label_dict', 'rb') as fp: + label_dict = pkl.load(fp) + data = data.X[:10] + + # 加载模型 + model = PerformerLM( + num_tokens = CLASS, + dim = 200, + depth = 6, + max_seq_len = SEQ_LEN, + heads = 10, + # local_attn_heads = 0, + # g2v_position_emb = True + ) + model.to_out = Identity(dropout=0.1, h_dim=128, out_dim=label_dict.shape[0]) + path = args.model_path + ckpt = ms.load_checkpoint(path) + ms.load_param_into_net(model, ckpt) + for param in model.trainable_params(): + param.requires_grad = False + + batch_size = data.shape[0] + model.set_train(False) + pred_finals = [] + for index in range(batch_size): + full_seq = data[index].toarray()[0] + full_seq[full_seq > (CLASS - 2)] = CLASS - 2 + full_seq = np.append(full_seq, 0).astype(np.int32) + full_seq = Tensor(full_seq).astype(ms.int32) # 转换为 MindSpore Tensor + full_seq = ops.expand_dims(full_seq, 0) # 在第 0 维度添加一个维度,类似于 unsqueeze + pred_logits = model(full_seq) + pred_prob = ops.softmax(pred_logits, axis = -1) + pred_final = pred_prob.argmax(axis=-1) + pred_finals.append(pred_final) + pred_list = [] + for pred_final in pred_finals: + pred_list.append(label_dict[pred_final]) + logging.info(f"Predictions: {pred_list}") \ No newline at end of file diff --git a/scBERT/pretrain.py b/scBERT/pretrain.py new file mode 100644 index 0000000000000000000000000000000000000000..3fca59bc6ccdea31bf4e734102b47da4a0db54cf --- /dev/null +++ b/scBERT/pretrain.py @@ -0,0 +1,303 @@ +import argparse +from dataset_pretrain import load_data +from performer import PerformerLM +from mindspore.nn import Adam, CrossEntropyLoss +from tqdm import tqdm +from mindspore import ops, save_checkpoint, Tensor +import math +from functools import reduce +import mindspore as ms +from mindspore import value_and_grad +from mindspore.communication import init +from mindspore.communication.management import get_rank +from mindspore import nn +from mindspore.nn import ExponentialDecayLR +import logging + +model = None +loss_fn = None + +def prob_mask_like(t, prob): + return ops.uniform(t.shape, Tensor(0, dtype=ms.float32), Tensor(1, dtype=ms.float32)).float() < prob + +def mask_with_tokens(t, token_ids): + init_no_mask = ops.full_like(t, False, dtype=ms.uint8) + mask = reduce(lambda acc, el: acc | (t == el), token_ids, init_no_mask) + return Tensor(mask, dtype=ms.uint8) + +def get_mask_subset_with_prob(mask, prob): + batch, seq_len = mask.shape + max_masked = math.ceil(prob * seq_len) # num of mask of a single sequence in average + num_tokens = mask.sum(axis=-1, keepdims=True) # num of pure tokens of each sequence except special tokens + mask_excess = ops.cat((ops.zeros(size=(batch), dtype=ms.float32), ops.arange(1, seq_len,dtype=ms.float32).repeat(batch))).reshape(batch,seq_len) + mask_excess = (mask_excess >= (num_tokens * prob).ceil()) # only 15% of pure tokens can be masked + mask_excess = ops.Reshape()(mask_excess, 
(batch, seq_len)) + mask_excess = mask_excess[:, :max_masked] # get difference between 15% of pure tokens and 15% of all tokens + rand = ops.rand((batch, seq_len)).masked_fill(~mask, -1e9) # rand (0-1) as prob, special token use -1e9 + _, sampled_indices = rand.topk(max_masked, dim=-1) # get index of topk prob to mask + sampled_indices = (sampled_indices + 1).masked_fill(mask_excess, 0) # delete difference of mask not pure + new_mask = ops.zeros((batch, seq_len + 1), dtype=ms.uint8) # get (batch, seq_len) shape zero matrix + new_mask = new_mask.scatter(-1, sampled_indices, ops.ones(shape=ops.shape(sampled_indices), dtype=ms.uint8)) # set masks in zero matrix as 1 + new_mask = ops.Cast()(new_mask, ms.uint8) + return new_mask[:, 1:] # the final mask, True is mask + +def data_mask(data, + mask_prob=None, + replace_prob=None, + num_tokens=None, + random_token_prob=None, + mask_token_id=None, + pad_token_id=None, + mask_ignore_token_ids=None +): + global MASK_PROB, REPLACE_PROB, RANDOM_TOKEN_PROB, MASK_TOKEN_ID, PAD_TOKEN_ID, MASK_IGNORE_TOKEN_IDS + replace_prob = REPLACE_PROB + mask_prob= MASK_PROB + random_token_prob = RANDOM_TOKEN_PROB + mask_token_id = MASK_TOKEN_ID + pad_token_id = PAD_TOKEN_ID + mask_ignore_token_ids = MASK_IGNORE_TOKEN_IDS + + mask_ignore_token_ids = set([*mask_ignore_token_ids, pad_token_id]) + # do not mask [pad] tokens, or any other tokens in the tokens designated to be excluded ([cls], [sep]) + # also do not include these special tokens in the tokens chosen at random + no_mask = mask_with_tokens(data, mask_ignore_token_ids) # ignore_token as True, will not be masked later + mask = get_mask_subset_with_prob(~no_mask, mask_prob) # get the True/False mask matrix + # get mask indices + ## mask_indices = torch.nonzero(mask, as_tuple=True) # get the index of mask(nonzero value of mask matrix) + # mask input with mask tokens with probability of `replace_prob` (keep tokens the same with probability 1 - replace_prob) + masked_input = data + # if random token probability > 0 for mlm + if random_token_prob > 0: + assert num_tokens is not None, 'num_tokens keyword must be supplied when instantiating MLM if using random token replacement' + random_token_prob = prob_mask_like(data, random_token_prob) # get the mask matrix of random token replace + random_tokens = ops.randint(0, num_tokens, data.shape) # generate random token matrix with the same shape as input + random_no_mask = mask_with_tokens(random_tokens, mask_ignore_token_ids) # not masked matrix for the random token matrix + random_token_prob &= ~random_no_mask # get the pure mask matrix of random token replace + random_indices = ops.nonzero(random_token_prob, as_tuple=True) # index of random token replace + masked_input[random_indices] = random_tokens[random_indices] # replace some tokens by random token + # [mask] input + replace_prob = prob_mask_like(data, replace_prob) # get the mask matrix of token being masked + masked_input = masked_input.masked_fill(ops.Cast()(mask * replace_prob, ms.bool_), mask_token_id) # get the data has been masked by mask_token + # mask out any tokens to padding tokens that were not originally going to be masked + labels = data.masked_fill(~mask, pad_token_id) # the label of masked tokens + return masked_input, labels + +def build_model(args): + global CLASS, SEQ_LEN, POS_EMBED_USING, model + model = PerformerLM( + num_tokens = CLASS, # 7 + dim = 200, + depth = 1, + max_seq_len = SEQ_LEN, # 16907 + heads = 10, + #local_attn_heads = 0, + ) + print("build model success.") + count = sum([ 
item.size for item in model.get_parameters()]) + names = [item.name for item in model.trainable_params()] + + print("param count is {}, names: {}, count: {}".format(count, str(names), len(names))) + + if args.enable_pipeline: + model.init_pipeline() + model.performer.layers[0].init_pipeline() + model.performer.layers[0].attention.init_pipeline() + return model + + +def build_optimizer_and_scheduler(model): + global LEARNING_RATE, PAD_TOKEN_ID, loss_fn, optimizer + # optimizer + optimizer = Adam(params=model.trainable_params(), learning_rate=LEARNING_RATE) + loss_fn = CrossEntropyLoss(ignore_index = PAD_TOKEN_ID, reduction='mean') + print("build optimizer success.") + return optimizer, loss_fn + +def train_one_epoch(train_dataloader, grad_fn, optimizer, pp_grad_reducer): + global PAD_TOKEN_ID, model, PIPELINE, BATCH_SIZE, SEQ_LEN, DP, lr_schedule + running_loss = 0.0 + cum_acc = 0.0 + model.set_train(True) + correct_num = 0 + val_num = 0 + for index, (data,) in enumerate(tqdm(train_dataloader.create_tuple_iterator())): + data, orig_labels = data_mask(data) + labels = ops.repeat_elements(orig_labels, rep=7, axis=-1) + labels = ops.cast(labels, dtype=ms.float32) + if (labels.shape[0] % BATCH_SIZE) !=0: + continue + labels = ops.reshape(labels, (BATCH_SIZE,SEQ_LEN, 7)) + if PIPELINE: + loss, grads = grad_fn(data, labels) + grads = pp_grad_reducer(grads) + elif DP: + (loss, logits), grads = grad_fn(data, labels) + grads = pp_grad_reducer(grads) + else: + (loss, logits), grads = grad_fn(data, labels) + optimizer(grads) + lr = lr_schedule(index) + optimizer.learning_rate = lr + # 累加损失 + running_loss += loss.item() / (SEQ_LEN*BATCH_SIZE*7) + # 计算精度 + if not PIPELINE: + labels = ops.repeat_elements(orig_labels, rep=7, axis=-1) + labels = ops.reshape(labels, (-1, SEQ_LEN, 7)) + labels = ops.cast(labels, dtype=ms.float32) + final = ops.softmax(logits, axis=-1)[..., 1:-1] # (bs, seq_len, 7) + final = final.argmax(axis=-1) + 1 # # (bs, seq_len) + correct_num += ops.mul(Tensor(orig_labels!=PAD_TOKEN_ID, dtype=ms.uint8), Tensor(final == orig_labels, dtype=ms.uint8)).sum(axis=-1).sum() + val_num += Tensor(orig_labels != PAD_TOKEN_ID, dtype=ms.uint8).sum(axis=-1).sum() + del data, labels, logits, final, orig_labels + + return running_loss, 100 * correct_num / val_num + +def eval_one_epoch(val_dataloader): + global PAD_TOKEN_ID, loss_fn, model, SEQ_LEN + model.set_train(False) + predictions = [] + truths = [] + running_loss = 0.0 + print("========== 开始验证") + correct_num = 0 + val_num = 0 + for _, (data,) in enumerate(tqdm(val_dataloader.create_tuple_iterator())): + data, ori_labels = data_mask(data) + ori_labels = ops.cast(ori_labels, ms.float32) + labels = ops.repeat_elements(ori_labels, rep=7, axis=-1) + labels = ops.reshape(labels, (-1, SEQ_LEN, 7)) + labels = ops.cast(labels, dtype=ms.float32) + logits = model(data) + loss = loss_fn(logits, labels) + running_loss += loss.item() / (SEQ_LEN*BATCH_SIZE*7) + final = ops.softmax(logits, axis=-1)[..., 1:-1] + final = final.argmax(axis=-1) + 1 + correct_num += ops.mul(Tensor(ori_labels!=PAD_TOKEN_ID, dtype=ms.uint8), Tensor(final == ori_labels, dtype=ms.uint8)).sum(axis=-1).sum() + val_num += Tensor(ori_labels != PAD_TOKEN_ID, dtype=ms.uint8).sum(axis=-1).sum() + del data, labels, logits, final, ori_labels + val_loss = running_loss / len(val_dataloader) + val_acc = 100 * correct_num / val_num + del predictions, truths + return val_loss, val_acc + +def train(optimizer, train_dataloader, val_dataloader): + global EPOCHS,VALIDATE_EVERY, MODEL_NAME, loss_fn, 
PIPELINE, DP + + train_num_step = len(train_dataloader) + if PIPELINE: + grad_fn = value_and_grad(forward_pipeline, grad_position=None, weights=optimizer.parameters) + pp_grad_reducer = nn.PipelineGradReducer(optimizer.parameters) + elif DP: + grad_fn = value_and_grad(forward, grad_position=None, weights=model.trainable_params(), has_aux=True) + pp_grad_reducer = nn.DistributedGradReducer(optimizer.parameters, mean=True) + else: + grad_fn = value_and_grad(forward, grad_position=None, weights=model.trainable_params(), has_aux=True) + pp_grad_reducer = None + + for epoch in range(EPOCHS): + running_loss, cum_acc = train_one_epoch(train_dataloader, grad_fn, optimizer, pp_grad_reducer) + # log epoch的信息 + epoch_loss = running_loss / train_num_step + logging.info(f' == Epoch: {epoch} | Training Loss: {epoch_loss:.6f} | Accuracy: {cum_acc.item():6.4f}% ==') + + # 进行一次验证 + if epoch % VALIDATE_EVERY == 0: + val_loss, val_acc = eval_one_epoch(val_dataloader) + logging.info(f' == Epoch: {epoch} | Validation Loss: {val_loss} | Accuracy: {val_acc.item()}% ==') + if get_rank() == 0: + # 存模型 + ckpt_dir = "./" + PRETRAIN_PATH + if not os.path.exists(ckpt_dir): + os.makedirs(ckpt_dir, exist_ok=True) + ckpt_file = f"pretrain-{epoch}.ckpt" + ckpt_path = os.path.join(ckpt_dir, ckpt_file) + save_checkpoint(model, ckpt_path) + +def forward_pipeline(data, label): + global net_with_loss + return net_with_loss(data,label) + +def forward(data, label): + global model, loss_fn + logits = model(data) + loss = loss_fn(logits, label) + return loss, logits + +def parse(): + parser = argparse.ArgumentParser() + parser.add_argument("--enable_pipeline", type=bool, default=True, help='Local process rank.') + parser.add_argument("--device_id", type=int, default=-1, help='Local process rank.') + parser.add_argument("--bin_num", type=int, default=5, help='Number of bins.') + parser.add_argument("--gene_num", type=int, default=16906, help='Number of genes.') + parser.add_argument("--epoch", type=int, default=100, help='Number of epochs.') + parser.add_argument("--seed", type=int, default=2021, help='Random seed.') + parser.add_argument("--batch_size", type=int, default=4, help='Number of batch size.') + parser.add_argument("--learning_rate", type=float, default=1e-4, help='Learning rate.') + parser.add_argument("--valid_every", type=int, default=1, help='Number of training epochs between twice validation.') + parser.add_argument("--mask_prob", type=float, default=0.15, help='Probability of masking.') + parser.add_argument("--replace_prob", type=float, default=0.9, help='Probability of replacing with [MASK] token for masking.') + parser.add_argument("--pos_embed", type=bool, default=True, help='Using Gene2vec encoding or not.') + parser.add_argument("--data_path", type=str, default='./data/panglao_10000.h5ad', help='Path of data for pretraining.') + parser.add_argument("--model_name", type=str, default='panglao_pretrain', help='Pretrained model name.') + args = parser.parse_args() + return args + +if __name__ == "__main__": + # 创建日志记录器,文件保存日志,同时控制台也输出日志 + logging.basicConfig( + level=logging.INFO, + format='%(asctime)s - %(levelname)s - %(message)s', + handlers=[ + logging.FileHandler("training_log.log", mode='a'), # 保存日志到文件 + logging.StreamHandler() # 同时输出日志到控制台 + ] + ) + + args = parse() + args.enable_pipeline = False + args.enable_dp = True + if args.enable_pipeline: + ms.set_context(mode=0, device_target="Ascend") + ms.reset_auto_parallel_context() + 
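# semi-auto parallel with 4 pipeline stages; note that enable_pipeline is forced to False a few lines above, so this branch stays disabled unless that override is removed +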
ms.set_auto_parallel_context(parallel_mode=ms.ParallelMode.SEMI_AUTO_PARALLEL, pipeline_stages=4, pipeline_result_broadcast=True) + init() + ms.set_seed(1) + elif args.enable_dp: + ms.set_context(mode=0, device_target="Ascend", max_device_memory="29GB") + ms.reset_auto_parallel_context() + ms.set_auto_parallel_context(parallel_mode=ms.ParallelMode.DATA_PARALLEL, gradients_mean=True) + init() + ms.set_seed(1) + else: + ms.set_context(mode=0, device_target="Ascend", device_id=1) + + SEED = args.seed + EPOCHS = args.epoch + BATCH_SIZE = args.batch_size + LEARNING_RATE = args.learning_rate + SEQ_LEN = 5000 + VALIDATE_EVERY = args.valid_every + CLASS = args.bin_num + 2 + MASK_PROB = args.mask_prob + REPLACE_PROB = args.replace_prob + RANDOM_TOKEN_PROB = 0. + MASK_TOKEN_ID = CLASS - 1 + PIPELINE = args.enable_pipeline + PAD_TOKEN_ID = CLASS - 1 + MASK_IGNORE_TOKEN_IDS = [0] + POS_EMBED_USING = args.pos_embed + MODEL_NAME = args.model_name + DP = args.enable_dp + train_dataloader, val_dataloader = load_data(args.data_path, CLASS, SEED, BATCH_SIZE, SEQ_LEN, args) + model = build_model(args) + + optimizer,loss_fn = build_optimizer_and_scheduler(model) + if args.enable_pipeline: + global net_with_loss + net_with_loss = nn.PipelineCell(nn.WithLossCell(model, loss_fn), micro_size=4) + net_with_loss.set_train() + global lr_schedule + lr_schedule = ExponentialDecayLR(learning_rate=0.1, decay_rate=0.9, decay_steps=100) + train(optimizer, train_dataloader, val_dataloader) \ No newline at end of file diff --git a/scBERT/run_distribute_finetune.sh b/scBERT/run_distribute_finetune.sh new file mode 100644 index 0000000000000000000000000000000000000000..36e993d62d375b2b43ec77379be8137354c0f208 --- /dev/null +++ b/scBERT/run_distribute_finetune.sh @@ -0,0 +1,9 @@ +#!/bin/bash + +echo "==============================================================================================================" +echo "Please run the script as: " +echo "bash run.sh" +echo "==============================================================================================================" + + +msrun --worker_num=8 --local_worker_num=8 --master_port=8118 --log_dir=msrun_log --join=True --cluster_time_out=300 finetune.py diff --git a/scBERT/run_distribute_pretrain.sh b/scBERT/run_distribute_pretrain.sh new file mode 100644 index 0000000000000000000000000000000000000000..d941a94aa90baa58481deeea50648dc9e7e4de00 --- /dev/null +++ b/scBERT/run_distribute_pretrain.sh @@ -0,0 +1,9 @@ +#!/bin/bash + +echo "==============================================================================================================" +echo "Please run the script as: " +echo "bash run.sh" +echo "==============================================================================================================" + + +msrun --worker_num=8 --local_worker_num=8 --master_port=8118 --log_dir=msrun_log --join=True --cluster_time_out=300 pretrain.py diff --git a/scBERT/utils.py b/scBERT/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..5e864e94adc016999b60efc8b9134ff2bc1a71d5 --- /dev/null +++ b/scBERT/utils.py @@ -0,0 +1,22 @@ +from contextlib import contextmanager + +def exists(val): + return val is not None +def empty(tensor): # mindspore的Tensor.size 返回 张量中元素的个数 + return tensor.size == 0 +def default(val, d): + return val if exists(val) else d +@contextmanager +def null_context(): + yield +def cast_tuple(val): + return (val,) if not isinstance(val, tuple) else val +def find_modules(nn_module, module_type): + return [module for module in 
nn_module.cells() if isinstance(module, module_type)] # in MindSpore, Cell.cells() returns the immediate child cells + +def route_args(router, pos_emb, depth): + routed_args = [dict() for _ in range(depth)] + + for i in range(depth): + routed_args[i] = ({"pos_emb": pos_emb}, {}) + return routed_args \ No newline at end of file