diff --git a/research/arxiv_paper/README-example.pdf b/research/arxiv_paper/README-example.pdf new file mode 100644 index 0000000000000000000000000000000000000000..e91859948610f3a1fbfbd3db3e8bc7de1020a7f4 Binary files /dev/null and b/research/arxiv_paper/README-example.pdf differ diff --git a/research/arxiv_paper/README.md b/research/arxiv_paper/README.md new file mode 100644 index 0000000000000000000000000000000000000000..3bc0b84a9a7ec463c9a382e7f045c9b43c71e95b --- /dev/null +++ b/research/arxiv_paper/README.md @@ -0,0 +1,10 @@ +## MobileNetV2: A lightweight classification model for home-based sleep apnea screening + +In this study, we propose the use of a lightweight neural network model for learning sleep stage classification from electrocardiogram (ECG) signal feature maps. Subsequently, the number of respiratory events occurring during sleep is determined based on the developed apnea event detection algorithm. The Apnea-Hypopnea Index (AHI) is calculated by dividing the number of respiratory events by the total sleep duration, thereby enabling the assessment of the subject’s risk of sleep apnea disorders. Extensive validation is performed across multiple publicly available datasets, demonstrating high disease detection rates and indicating the model's promising potential for practical applications. 
import os
import zipfile


def zip_folder(folder_path, zip_file_path):
    """Pack every file below ``folder_path`` into a deflate-compressed zip.

    Args:
        folder_path: Directory that is walked recursively.
        zip_file_path: Archive to create; an existing file is overwritten.

    Members are stored relative to ``folder_path`` and written in sorted
    order so that repeated runs yield the same member order.
    """
    with zipfile.ZipFile(zip_file_path, 'w', zipfile.ZIP_DEFLATED) as zipf:
        for root, dirs, files in os.walk(folder_path):
            dirs.sort()  # sorting in place also fixes os.walk's descent order
            for file in sorted(files):
                file_path = os.path.join(root, file)
                relative_path = os.path.relpath(file_path, folder_path)
                zipf.write(file_path, relative_path)


def zip_each_folder(target_dir):
    """Create ``<name>.zip`` next to every sub-directory ``<name>`` of ``target_dir``.

    Plain files inside ``target_dir`` (including archives from a previous
    run) are ignored.
    """
    for item in sorted(os.listdir(target_dir)):
        folder_path = os.path.join(target_dir, item)

        if os.path.isdir(folder_path):
            zip_file_name = item + '.zip'
            zip_file_path = os.path.join(target_dir, zip_file_name)

            zip_folder(folder_path, zip_file_path)

            print(f"{item} 已压缩为 {zip_file_name}")


def unzip(zipath, savefolder):
    """Extract the archive ``zipath`` into the directory ``savefolder``.

    Args:
        zipath: Path of the zip archive to read.
        savefolder: Directory that receives all members of the archive.
    """
    # The context manager closes the archive even when extraction fails.
    with zipfile.ZipFile(zipath) as zf:
        zf.extractall(savefolder)


def unzip_each(zip_dir, save_root, stem_length=None, verbose=False):
    """Extract every archive in ``zip_dir`` into ``save_root/<stem>``.

    ``<stem>`` is the file name up to its first dot.

    Args:
        zip_dir: Directory holding one zip archive per subject.
        save_root: Directory under which one folder per archive is created.
        stem_length: When given, only archives whose stem has exactly this
            many characters are extracted.
        verbose: Print each target folder before extracting into it.

    Entries that are not zip archives (directories, stray files) are skipped,
    and existing target folders are reused so the step can be re-run.
    """
    os.makedirs(save_root, exist_ok=True)
    for file in sorted(os.listdir(zip_dir)):
        filePath = os.path.join(zip_dir, file)
        if not zipfile.is_zipfile(filePath):
            continue
        result = file.split(".", 1)[0]
        if stem_length is not None and len(result) != stem_length:
            continue
        savefolder = os.path.join(save_root, result)
        os.makedirs(savefolder, exist_ok=True)
        if verbose:
            print(savefolder)
        unzip(filePath, savefolder)


# --- cell: step 1, zip every per-subject folder produced by MATLAB ---
if __name__ == "__main__":
    target_directory = './tryzip'
    zip_each_folder(target_directory)

# --- cell: step 3, unpack the per-subject training archives ---
if __name__ == "__main__":
    zipath = os.path.join(os.getcwd(), '1_AP_zip')
    savepath = os.path.join(os.getcwd(), 'unzip')
    unzip_each(zipath, savepath)

# --- cell: step 3, unpack the test archives (three-character stems only) ---
if __name__ == "__main__":
    zipath = os.path.join(os.getcwd(), '1_AP_abc')
    savepath = os.path.join(os.getcwd(), '0-AP', 'For_test1')
    unzip_each(zipath, savepath, stem_length=3, verbose=True)
import os
import shutil


def CreateDir(path):
    """Create ``path`` (with parents) unless it already exists, and report which happened."""
    isExists = os.path.exists(path)
    if not isExists:
        os.makedirs(path)
        print(path+' 目录创建成功')
    else:
        print(path+' 目录已存在')


def CopyFile(filepath, newPath):
    """Regroup a per-subject tree into a per-class tree.

    ``filepath`` is read as ``<subject>/<class>/<file>``; every file is copied
    to ``newPath/<class>/``.

    Args:
        filepath: Root directory holding one folder per subject.
        newPath: Root directory of the per-class dataset.

    Entries that are not directories are skipped at the subject and class
    levels. A class folder missing under ``newPath`` is created; without it
    ``shutil.copy`` would write a single plain file named after the class
    and overwrite it with every following image.
    """
    for file in sorted(os.listdir(filepath)):
        newDir = os.path.join(filepath, file)
        if not os.path.isdir(newDir):
            continue
        for typefile in sorted(os.listdir(newDir)):
            newDirson = os.path.join(newDir, typefile)
            if not os.path.isdir(newDirson):
                continue
            newFile = os.path.join(newPath, typefile)
            os.makedirs(newFile, exist_ok=True)
            # NOTE(review): files with the same name coming from different
            # subjects land in the same class folder and overwrite each
            # other -- confirm that file names are unique across subjects.
            for image in os.listdir(newDirson):
                imagePath = os.path.join(newDirson, image)
                shutil.copy(imagePath, newFile)


if __name__ == "__main__":

    rootpath = os.getcwd()

    frompath = os.path.join(rootpath, 'unzip')
    topath = os.path.join(rootpath, 'data_Sas_AP')
    toApath = os.path.join(topath, 'A')
    toNpath = os.path.join(topath, 'N')
    # Create the target folders.
    CreateDir(topath)
    CreateDir(toApath)
    CreateDir(toNpath)
    CopyFile(frompath, topath)
    # The unpacked per-subject tree is only an intermediate; remove it.
    shutil.rmtree(frompath)
import os
import shutil
import random
from shutil import copy


def CreateDir(path):
    """Create ``path`` (with parents) unless it already exists, and report which happened."""
    isExists = os.path.exists(path)
    if not isExists:
        os.makedirs(path)
        print(path+' 目录创建成功')
    else:
        print(path+' 目录已存在')


def split_subject_images(subject_dir, target_dir, ns_ratio=0.3):
    """Distribute the files of one unlabeled subject over ``S`` and ``NS`` folders.

    ``int(len(files) * ns_ratio)`` randomly sampled files are copied to
    ``target_dir/NS``; all remaining files are copied to ``target_dir/S``.

    Args:
        subject_dir: Folder holding the files of a single subject.
        target_dir: Folder that receives the ``S`` and ``NS`` sub-folders.
        ns_ratio: Fraction of the files that is sent to ``NS``.

    Returns:
        Tuple ``(number copied to S, number copied to NS)``.
    """
    newfileSpath = os.path.join(target_dir, 'S')
    newfileNSpath = os.path.join(target_dir, 'NS')
    CreateDir(newfileSpath)
    CreateDir(newfileNSpath)

    # Sorted listing: os.listdir order is arbitrary, so sampling from it
    # would not be repeatable even with a fixed seed.
    images = sorted(os.listdir(subject_dir))
    # A set keeps the membership test below O(1) per file.
    ns_images = set(random.sample(images, k=int(len(images) * ns_ratio)))
    for image in images:
        destination = newfileNSpath if image in ns_images else newfileSpath
        copy(os.path.join(subject_dir, image), os.path.join(destination, image))
    return len(images) - len(ns_images), len(ns_images)


if __name__ == "__main__":

    # Fixed seed so that re-running the cell reproduces the same assignment.
    random.seed(0)

    rootpath = os.getcwd()

    frompath = os.path.join(rootpath, '0-AP', 'For_test2')
    topath = os.path.join(rootpath, '0-AP', 'For_test_s')
    for file in sorted(os.listdir(frompath)):
        split_subject_images(os.path.join(frompath, file), os.path.join(topath, file))
import os
from shutil import copy, rmtree
import random
import json


def mk_file(file_path: str):
    """Create ``file_path`` as an empty directory, deleting any existing one first."""
    if os.path.exists(file_path):
        # Start from a clean folder so files of an earlier split do not linger.
        rmtree(file_path)
    os.makedirs(file_path)


def main():
    """Split ``./data_Sas_AP/<class>/*`` into train and test folders.

    For every class folder, ``int(n * 0.1)`` randomly chosen files are copied
    to ``./model_generator_sas/test/<class>`` and the rest to
    ``./model_generator_sas/train/<class>``. Both output trees are recreated
    from scratch on every call.
    """
    # Fixed seed so the split can be reproduced.
    random.seed(0)

    # Fraction of every class that goes to the test set.
    split_rate = 0.1

    cwd = os.getcwd()
    data_root = os.path.join(cwd, "data_Sas_AP")
    assert os.path.exists(data_root), "path '{}' does not exist.".format(data_root)

    # One sub-folder per class; sorted so the processing order is stable.
    flower_class = sorted(cla for cla in os.listdir(data_root)
                          if os.path.isdir(os.path.join(data_root, cla)))

    learn_root = os.path.join(cwd, "model_generator_sas")

    # Fresh train/<class> folders.
    train_root = os.path.join(learn_root, "train")
    mk_file(train_root)
    for cla in flower_class:
        mk_file(os.path.join(train_root, cla))

    # Fresh test/<class> folders.
    test_root = os.path.join(learn_root, "test")
    mk_file(test_root)
    for cla in flower_class:
        mk_file(os.path.join(test_root, cla))

    for cla in flower_class:
        cla_path = os.path.join(data_root, cla)
        # os.listdir order is arbitrary; without sorting the seeded sample
        # would still differ between machines.
        images = sorted(os.listdir(cla_path))
        num = len(images)
        # Set membership keeps the loop linear in the number of files.
        eval_index = set(random.sample(images, k=int(num*split_rate)))
        for index, image in enumerate(images):
            image_path = os.path.join(cla_path, image)
            if image in eval_index:
                new_path = os.path.join(test_root, cla)
            else:
                new_path = os.path.join(train_root, cla)
            copy(image_path, new_path)
            print("\r[{}] processing [{}/{}]".format(cla, index+1, num), end="")  # processing bar
    print("Finish!")


if __name__ == '__main__':
    main()
import os
import shutil
import zipfile


def copy_folder(source_folder, destination_folder):
    """Recursively copy the contents of ``source_folder`` into ``destination_folder``.

    Args:
        source_folder: Directory to read.
        destination_folder: Directory to write; created when missing.

    Files already present in the destination are kept unless a source file
    has the same relative path, in which case it is overwritten.
    ``shutil.copy2`` is used, so file metadata is copied as well.
    """
    os.makedirs(destination_folder, exist_ok=True)

    for item in os.listdir(source_folder):
        source = os.path.join(source_folder, item)
        destination = os.path.join(destination_folder, item)

        if os.path.isdir(source):
            copy_folder(source, destination)
        else:
            shutil.copy2(source, destination)


def extract_archive(zip_file, target_dir=None):
    """Extract ``zip_file`` into ``target_dir`` when the archive exists.

    Args:
        zip_file: Path of the zip archive.
        target_dir: Extraction directory; the current working directory
            when omitted.

    Returns:
        True when the archive was extracted, False when it was not found.
    """
    if target_dir is None:
        target_dir = os.getcwd()
    if not os.path.isfile(zip_file):
        # NOTE(review): download(..., kind="zip") in the previous cell
        # presumably unpacks the archive already and may not leave the .zip
        # behind -- confirm; a missing archive is reported, not fatal.
        print(f"{zip_file} not found, nothing to extract")
        return False
    with zipfile.ZipFile(zip_file, 'r') as zip_ref:
        zip_ref.extractall(target_dir)
    print(f"Extracted {zip_file} to {target_dir}")
    return True


# --- cell (0-preprocess): merge the test split back into the train folder ---
if __name__ == "__main__":
    source = './model_generator_sas/test'
    destination = './model_generator_sas/train'

    copy_folder(source, destination)

# --- cell (01-mobilenetV2_train): unpack the pretrained checkpoint ---
if __name__ == "__main__":
    # The variable used to be defined as `zip_files` but read as `zip_file`.
    zip_file = "mobilenetV2-200_1067.zip"
    extract_archive(zip_file)
import os

# GLOG_v selects the MindSpore log level: 3 (ERROR), 2 (WARNING), 1 (INFO), 0 (DEBUG).
# It is exported before MindSpore is imported because the variable is
# presumably read while the package sets up its logger -- TODO confirm.
os.environ['GLOG_v'] = '3'

import math
import random

import numpy as np
from matplotlib import pyplot as plt
from easydict import EasyDict
from PIL import Image
import mindspore.nn as nn
from mindspore import ops as P
from mindspore.ops import add
from mindspore import Tensor
import mindspore.common.dtype as mstype
import mindspore.dataset as de
import mindspore.dataset.vision as C
import mindspore.dataset.transforms as C2
import mindspore as ms
from mindspore import set_context, nn, Tensor, load_checkpoint, save_checkpoint, export
from mindspore.train import Model
from mindspore.train import Callback, LossMonitor, ModelCheckpoint, CheckpointConfig

# Run in graph mode on Ascend device 0.
set_context(mode=ms.GRAPH_MODE, device_target="Ascend", device_id=0)


# Label tables. `garbage_classes` is not read anywhere in the visible cells.
# NOTE(review): it looks like a leftover from another project -- confirm
# before removing it.
garbage_classes = {
    '第一部分': ['NS', 'S'],
    '第二部分': ['塑料瓶盖', '饮料瓶', '玻璃瓶', '纸杯', '塑料袋', '垃圾袋']
}

class_cn = ['NS', 'S']
class_en = ['NS', 'S']
index_en = {'NS': 0, 'S': 1}

# Training / evaluation configuration.
config = EasyDict({
    "num_classes": 2,
    "image_height": 224,
    "image_width": 224,
    "data_split": (0.9, 0.1),
    "backbone_out_channels": 1280,
    "batch_size": 32,
    "eval_batch_size": 8,
    "epochs": 20,
    "lr_max": 0.0005,
    "momentum": 0.9,
    "weight_decay": 1e-4,
    "save_ckpt_epochs": 1,
    "save_ckpt_path": "./ckpt2",  # "./" is the current working directory
    # "dataset_path": "./data_en",
    "dataset_path": "./model_generator",  # Updated to new dataset path
    "class_index": index_en,
    "pretrained_ckpt": "./mobilenetV2-200_1067.ckpt"
})


def create_dataset(dataset_path, config, training=True, buffer_size=1000):
    """Build the batched training or evaluation dataset.

    Args:
        dataset_path (string): Root folder containing ``train`` and ``test``
            sub-folders in ImageFolder layout.
        config (struct): Provides ``class_index``, ``image_height``,
            ``image_width``, ``batch_size`` and ``eval_batch_size``.
        training (bool): True reads ``train`` and applies random crop and
            horizontal flip; False reads ``test`` and applies resize plus
            center crop.
        buffer_size (int): Shuffle buffer size, used for training only.

    Returns:
        A single batched dataset with a normalised CHW ``image`` column and
        an int32 ``label`` column. Incomplete last batches are dropped.
    """
    data_path = os.path.join(dataset_path, 'train' if training else 'test')
    ds = de.ImageFolderDataset(data_path, num_parallel_workers=4, class_indexing=config.class_index)
    resize_height = config.image_height
    resize_width = config.image_width

    # Mean / std are expressed on the 0-255 pixel scale.
    normalize_op = C.Normalize(mean=[0.485*255, 0.456*255, 0.406*255], std=[0.229*255, 0.224*255, 0.225*255])
    change_swap_op = C.HWC2CHW()
    type_cast_op = C2.TypeCast(mstype.int32)

    if training:
        crop_decode_resize = C.RandomCropDecodeResize(resize_height, scale=(0.08, 1.0), ratio=(0.75, 1.333))
        horizontal_flip_op = C.RandomHorizontalFlip(prob=0.5)

        train_trans = [crop_decode_resize, horizontal_flip_op, normalize_op, change_swap_op]
        train_ds = ds.map(input_columns="image", operations=train_trans, num_parallel_workers=4)
        train_ds = train_ds.map(input_columns="label", operations=type_cast_op, num_parallel_workers=4)

        train_ds = train_ds.shuffle(buffer_size=buffer_size)
        ds = train_ds.batch(config.batch_size, drop_remainder=True)

    else:
        decode_op = C.Decode()
        # Resize to target / 0.875 so the center crop of the target size keeps
        # the central 87.5 % of the image. Multiplying by 0.875 instead would
        # make the image smaller than the crop window.
        resize_op = C.Resize((int(resize_height / 0.875), int(resize_width / 0.875)))
        center_crop = C.CenterCrop((resize_height, resize_width))

        eval_trans = [decode_op, resize_op, center_crop, normalize_op, change_swap_op]
        eval_ds = ds.map(input_columns="image", operations=eval_trans, num_parallel_workers=4)
        eval_ds = eval_ds.map(input_columns="label", operations=type_cast_op, num_parallel_workers=4)
        ds = eval_ds.batch(config.eval_batch_size, drop_remainder=True)

    return ds
# Preview the first four pre-processed training images.
ds = create_dataset(dataset_path=config.dataset_path, config=config, training=True)
print(ds.get_dataset_size())
# next() on the public iterator replaces the private `_get_next()` call.
data = next(ds.create_dict_iterator(output_numpy=True))
images = data['image']
labels = data['label']

# Indices 0..3 are the first four samples; subplot positions are 1-based.
for i in range(4):
    plt.subplot(2, 2, i + 1)
    # Back from CHW to HWC for matplotlib. The values are still normalised,
    # so imshow clips them to its displayable range.
    plt.imshow(np.transpose(images[i], (1, 2, 0)))
    plt.title(f'标签: {class_en[labels[i]]}')
    plt.xticks([])
    plt.yticks([])

plt.show()


import math
from mindspore import nn, Tensor, ops as P
import numpy as np

__all__ = ['MobileNetV2', 'MobileNetV2Backbone', 'MobileNetV2Head', 'mobilenet_v2']


def _make_divisible(v, divisor, min_value=None):
    """Round ``v`` to the nearest multiple of ``divisor``.

    Args:
        v: Value to round (a channel count in this file).
        divisor: The result is a multiple of this number.
        min_value: Lower bound of the result; defaults to ``divisor``.

    Returns:
        The rounded value, raised by one more ``divisor`` when rounding
        would otherwise fall below 90 % of ``v``.
    """
    if min_value is None:
        min_value = divisor
    new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
    if new_v < 0.9 * v:
        new_v += divisor
    return new_v
class GlobalAvgPooling(nn.Cell):
    """Average the input over axes 2 and 3.

    Assumes an NCHW input, so the result would be (N, C) -- TODO confirm
    against the callers.
    """

    def __init__(self):
        super(GlobalAvgPooling, self).__init__()

    def construct(self, x):
        x = P.ReduceMean()(x, (2, 3))
        return x

class ConvBNReLU(nn.Cell):
    """Conv2d followed by BatchNorm2d and ReLU6.

    Args:
        in_planes: Number of input channels.
        out_planes: Number of output channels.
        kernel_size: Convolution kernel size.
        stride: Convolution stride.
        groups: 1 builds a regular convolution; any other value builds a
            convolution with ``group=in_planes`` whose output channel count
            is ``in_planes``.
    """

    def __init__(self, in_planes, out_planes, kernel_size=3, stride=1, groups=1):
        super(ConvBNReLU, self).__init__()
        # Symmetric explicit padding derived from the kernel size.
        padding = (kernel_size - 1) // 2
        in_channels = in_planes
        out_channels = out_planes
        if groups == 1:
            conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, pad_mode='pad', padding=padding)
        else:
            # The value of `groups` itself is not used: the convolution is
            # always grouped per input channel.
            out_channels = in_planes
            conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, pad_mode='pad',
                             padding=padding, group=in_channels)

        # NOTE(review): BatchNorm2d is sized with out_planes while the grouped
        # branch emits in_planes channels; the visible caller passes equal
        # values, so they coincide -- confirm before calling it otherwise.
        layers = [conv, nn.BatchNorm2d(out_planes), nn.ReLU6()]
        self.features = nn.SequentialCell(layers)

    def construct(self, x):
        output = self.features(x)
        return output

class InvertedResidual(nn.Cell):
    """Inverted residual block: optional 1x1 expansion, grouped 3x3 convolution, 1x1 projection.

    Args:
        inp: Number of input channels.
        oup: Number of output channels.
        stride: Stride of the grouped convolution; must be 1 or 2.
        expand_ratio: Multiplier that yields the hidden channel count; the
            expansion layer is omitted when it equals 1.
    """

    def __init__(self, inp, oup, stride, expand_ratio):
        super(InvertedResidual, self).__init__()
        assert stride in [1, 2]

        hidden_dim = int(round(inp * expand_ratio))
        # The skip connection is only used when input and output shapes match.
        self.use_res_connect = stride == 1 and inp == oup

        layers = []
        if expand_ratio != 1:
            layers.append(ConvBNReLU(inp, hidden_dim, kernel_size=1))
        layers.extend([
            ConvBNReLU(hidden_dim, hidden_dim, stride=stride, groups=hidden_dim),
            # Projection without an activation after its BatchNorm.
            nn.Conv2d(hidden_dim, oup, kernel_size=1, stride=1, has_bias=False),
            nn.BatchNorm2d(oup),
        ])
        self.conv = nn.SequentialCell(layers)
        # Not referenced by construct() below.
        self.cast = P.Cast()

    def construct(self, x):
        identity = x
        x = self.conv(x)
        if self.use_res_connect:
            return P.Add()(identity, x)
        return x
class MobileNetV2Backbone(nn.Cell):
    """Feature extractor: stem convolution, inverted residual stages, final 1x1 convolution.

    Args:
        width_mult: Multiplier applied to every channel count.
        inverted_residual_setting: List of ``[t, c, n, s]`` rows; ``t`` is
            passed on as expand ratio, ``c`` is the stage's channel count
            before scaling, ``n`` the number of blocks, and ``s`` the stride
            of the first block of the stage. The built-in table is used when
            omitted.
        round_nearest: Channel counts are rounded to a multiple of this value.
        input_channel: Channel count of the stem before scaling.
        last_channel: Channel count of the final convolution before scaling.

    The resulting channel count is exposed as ``out_channels``.
    """

    def __init__(self, width_mult=1., inverted_residual_setting=None, round_nearest=8,
                 input_channel=32, last_channel=1280):
        super(MobileNetV2Backbone, self).__init__()
        block = InvertedResidual
        self.cfgs = inverted_residual_setting or [
            [1, 16, 1, 1],
            [6, 24, 2, 2],
            [6, 32, 3, 2],
            [6, 64, 4, 2],
            [6, 96, 3, 1],
            [6, 160, 3, 2],
            [6, 320, 1, 1],
        ]

        input_channel = _make_divisible(input_channel * width_mult, round_nearest)
        # The last layer is never narrowed below `last_channel` (max with 1.0).
        self.out_channels = _make_divisible(last_channel * max(1.0, width_mult), round_nearest)
        features = [ConvBNReLU(3, input_channel, stride=2)]

        for t, c, n, s in self.cfgs:
            output_channel = _make_divisible(c * width_mult, round_nearest)
            for i in range(n):
                # Only the first block of a stage applies the stage stride.
                stride = s if i == 0 else 1
                features.append(block(input_channel, output_channel, stride, expand_ratio=t))
                input_channel = output_channel
        features.append(ConvBNReLU(input_channel, self.out_channels, kernel_size=1))
        self.features = nn.SequentialCell(features)
        self._initialize_weights()

    def construct(self, x):
        x = self.features(x)
        return x

    def _initialize_weights(self):
        """Re-initialise Conv2d weights (normal, std sqrt(2 / n)) and reset BatchNorm to gamma=1, beta=0."""
        self.init_parameters_data()
        for _, m in self.cells_and_names():
            if isinstance(m, nn.Conv2d):
                # n = kernel area times the number of output channels.
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.set_data(Tensor(np.random.normal(0, np.sqrt(2. / n), m.weight.data.shape).astype("float32")))
                if m.bias is not None:
                    m.bias.set_data(Tensor(np.zeros(m.bias.data.shape, dtype="float32")))
            elif isinstance(m, nn.BatchNorm2d):
                m.gamma.set_data(Tensor(np.ones(m.gamma.data.shape, dtype="float32")))
                m.beta.set_data(Tensor(np.zeros(m.beta.data.shape, dtype="float32")))

    @property
    def get_features(self):
        """The underlying SequentialCell (a property, so it is read without calling)."""
        return self.features
class MobileNetV2Head(nn.Cell):
    """Classification head: global average pooling, optional dropout, dense layer.

    Args:
        input_channel: Number of features fed to the dense layer.
        num_classes: Number of outputs of the dense layer.
        has_dropout: Insert a Dropout cell in front of the dense layer.
        activation: "Sigmoid" or "Softmax" appends that activation; any
            other value (including the default string "None") returns the
            raw dense outputs.
    """

    def __init__(self, input_channel=1280, num_classes=1000, has_dropout=False, activation="None"):
        super(MobileNetV2Head, self).__init__()
        # NOTE(review): the positional 0.2 is presumably bound to `keep_prob`
        # (probability of keeping a unit) rather than to a drop probability
        # on the MindSpore version named in the README -- confirm which of
        # the two is intended.
        head = ([GlobalAvgPooling(), nn.Dense(input_channel, num_classes, has_bias=True)] if not has_dropout else
                [GlobalAvgPooling(), nn.Dropout(0.2), nn.Dense(input_channel, num_classes, has_bias=True)])
        self.head = nn.SequentialCell(head)
        self.need_activation = True
        if activation == "Sigmoid":
            self.activation = nn.Sigmoid()
        elif activation == "Softmax":
            self.activation = nn.Softmax()
        else:
            self.need_activation = False
        self._initialize_weights()

    def construct(self, x):
        x = self.head(x)
        if self.need_activation:
            x = self.activation(x)
        return x

    def _initialize_weights(self):
        """Re-initialise Dense weights with N(0, 0.01) and Dense biases with zeros."""
        self.init_parameters_data()
        for _, m in self.cells_and_names():
            if isinstance(m, nn.Dense):
                m.weight.set_data(Tensor(np.random.normal(0, 0.01, m.weight.data.shape).astype("float32")))
                if m.bias is not None:
                    m.bias.set_data(Tensor(np.zeros(m.bias.data.shape, dtype="float32")))

    @property
    def get_head(self):
        """The underlying SequentialCell, without the optional activation."""
        return self.head
class MobileNetV2(nn.Cell):
    """Complete MobileNetV2 classifier built from a fresh backbone and head.

    Args:
        num_classes: Number of outputs of the head.
        width_mult: Channel multiplier forwarded to the backbone.
        has_dropout: Forwarded to the head.
        inverted_residual_setting: Forwarded to the backbone.
        round_nearest: Forwarded to the backbone.
        input_channel: Forwarded to the backbone.
        last_channel: Forwarded to the backbone.
    """

    def __init__(self, num_classes=1000, width_mult=1., has_dropout=False, inverted_residual_setting=None,
                 round_nearest=8, input_channel=32, last_channel=1280):
        super(MobileNetV2, self).__init__()
        backbone = MobileNetV2Backbone(width_mult=width_mult, inverted_residual_setting=inverted_residual_setting,
                                       round_nearest=round_nearest, input_channel=input_channel,
                                       last_channel=last_channel)
        # `out_channels` lives on the backbone object, not on the
        # SequentialCell returned by `get_features`, so it is read before
        # only the feature stack is kept.
        self.backbone = backbone.get_features
        self.head = MobileNetV2Head(input_channel=backbone.out_channels, num_classes=num_classes,
                                    has_dropout=has_dropout).get_head

    def construct(self, x):
        x = self.backbone(x)
        x = self.head(x)
        return x


class MobileNetV2Combine(nn.Cell):
    """Chain an already constructed backbone and head.

    The cell is created with ``auto_prefix=False``.

    Args:
        backbone: Cell applied first.
        head: Cell applied to the backbone output.
    """

    def __init__(self, backbone, head):
        super(MobileNetV2Combine, self).__init__(auto_prefix=False)
        self.backbone = backbone
        self.head = head

    def construct(self, x):
        x = self.backbone(x)
        x = self.head(x)
        return x


def mobilenet_v2(backbone, head):
    """Return a ``MobileNetV2Combine`` wrapping ``backbone`` and ``head``."""
    return MobileNetV2Combine(backbone, head)
def cosine_lr_schedule(total_steps, lr_init, lr_end, lr_max, warmup_steps):
    """
    Per-step learning rates: linear warmup followed by cosine decay.

    Args:
        total_steps (int): Number of entries to generate.
        lr_init (float): Value the warmup ramp starts from.
        lr_end (float): Value the cosine decay approaches.
        lr_max (float): Value reached at the end of the warmup.
        warmup_steps (int): Number of leading warmup steps.

    Returns:
        list: One learning rate per training step.
    """
    lr_init = float(lr_init)
    lr_end = float(lr_end)
    lr_max = float(lr_max)

    decay_steps = total_steps - warmup_steps
    if warmup_steps:
        inc_per_step = (lr_max - lr_init) / warmup_steps
    else:
        inc_per_step = 0

    def rate_at(step):
        # Warmup: straight line from lr_init towards lr_max.
        if step < warmup_steps:
            return lr_init + inc_per_step * (step + 1)
        # Afterwards: half cosine wave from lr_max down towards lr_end.
        cosine_decay = 0.5 * (1 + math.cos(math.pi * (step - warmup_steps) / decay_steps))
        return (lr_max - lr_end) * cosine_decay + lr_end

    return [rate_at(step) for step in range(total_steps)]


def switch_precision(net, data_type):
    """On Ascend, cast ``net`` to ``data_type`` while keeping Dense cells in float32.

    Nothing happens on any other device target.
    """
    if ms.get_context('device_target') != "Ascend":
        return
    net.to_float(data_type)
    dense_cells = (cell for _, cell in net.cells_and_names() if isinstance(cell, nn.Dense))
    for dense in dense_cells:
        dense.to_float(ms.float32)
from mindspore.amp import FixedLossScaleManager
from mindspore import save_checkpoint
import os
import mindspore as ms

# Absolute path so later cells resolve the checkpoint folder the same way.
# NOTE(review): config.save_ckpt_path ("./ckpt2") is a different folder and is
# what the inference cell reads -- confirm which one is intended.
CKPT_PATH = os.path.abspath("./ckpt")
os.makedirs(CKPT_PATH, exist_ok=True)

# Load the data. The evaluation set must be requested explicitly: the default
# training=True would return the augmented training split a second time.
train_dataset = create_dataset(dataset_path=config.dataset_path, config=config)
eval_dataset = create_dataset(dataset_path=config.dataset_path, config=config, training=False)
step_size = train_dataset.get_dataset_size()

# Build the model.
backbone = MobileNetV2Backbone()

# Freeze the backbone so that only the head is optimised.
for param in backbone.get_parameters():
    param.requires_grad = False

# Load the pretrained backbone weights.
load_checkpoint(config.pretrained_ckpt, backbone)

head = MobileNetV2Head(input_channel=backbone.out_channels, num_classes=config.num_classes)
network = mobilenet_v2(backbone, head)

# Loss, learning-rate schedule and optimizer.
loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
lrs = cosine_lr_schedule(config.epochs * step_size, lr_init=0.0, lr_end=1e-5, lr_max=config.lr_max, warmup_steps=5)
# No loss_scale on the optimizer: the training step below never multiplies
# the loss by a scale factor, so asking the optimizer to divide the gradients
# by 1024 would shrink every update by that factor.
opt = nn.Momentum(network.trainable_params(), learning_rate=lrs, momentum=config.momentum, weight_decay=config.weight_decay)


def train_loop(model, dataset, loss_fn, optimizer):
    """Run one epoch over ``dataset`` and print the loss every 10 batches.

    Args:
        model: Network to train; switched to training mode.
        dataset: Dataset yielding ``(data, label)`` tuples.
        loss_fn: Callable ``loss_fn(logits, label)``.
        optimizer: Optimizer; gradients are taken with respect to
            ``optimizer.parameters``.
    """
    def forward_fn(data, label):
        logits = model(data)
        loss = loss_fn(logits, label)
        return loss

    grad_fn = ms.value_and_grad(forward_fn, None, optimizer.parameters)

    def train_step(data, label):
        loss, grads = grad_fn(data, label)
        optimizer(grads)
        return loss

    size = dataset.get_dataset_size()
    # NOTE(review): this also puts the frozen backbone into training mode, so
    # its BatchNorm layers presumably use batch statistics -- confirm that
    # this is wanted.
    model.set_train()
    for batch, (data, label) in enumerate(dataset.create_tuple_iterator()):
        loss = train_step(data, label)

        if batch % 10 == 0:
            loss_val, current = loss.asnumpy(), batch
            print(f"loss: {loss_val:>7f} [{current:>3d}/{size:>3d}]")


def test_loop(model, dataset, loss_fn):
    """Evaluate ``model`` on ``dataset`` and print accuracy and mean batch loss.

    Args:
        model: Network to evaluate; switched to inference mode.
        dataset: Dataset yielding ``(data, label)`` tuples.
        loss_fn: Callable ``loss_fn(logits, label)``.
    """
    num_batches = dataset.get_dataset_size()
    model.set_train(False)
    total, test_loss, correct = 0, 0, 0
    for data, label in dataset.create_tuple_iterator():
        pred = model(data)
        total += data.shape[0]
        test_loss += loss_fn(pred, label).asnumpy()
        correct += (pred.argmax(axis=1) == label).asnumpy().sum()
    if num_batches == 0 or total == 0:
        # drop_remainder=True yields no batch when the split holds fewer
        # samples than one batch.
        print("Test: \n no evaluation batches available \n")
        return
    test_loss /= num_batches
    correct /= total
    print(f"Test: \n Accuracy: {(100 * correct):>0.1f}%, Avg loss: {test_loss:>8f} \n")


# Train and evaluate; the checkpoint file is overwritten after every epoch.
print("============== Starting Training ==============")
epochs = config.epochs
for t in range(epochs):
    print(f"Epoch {t + 1}\n-------------------------------")
    train_loop(network, train_dataset, loss, opt)
    save_checkpoint(network, os.path.join(CKPT_PATH, "save_mobilenetV2_model.ckpt"))
    test_loop(network, eval_dataset, loss)
print("Done!")
def image_process(image):
    """Normalise one RGB image and wrap it as a single-item batch tensor.

    Args:
        image: Array-like of shape (H, W, 3) with channel values on the 0-255
            scale (the mean/std constants below are multiplied by 255).

    Returns:
        A float32 ``Tensor`` of shape (1, 3, H, W).

    Raises:
        ValueError: If ``image`` is not a 3-channel (H, W, 3) array.
    """
    pixels = np.array(image)
    if pixels.ndim != 3 or pixels.shape[2] != 3:
        raise ValueError(
            "image_process expects an (H, W, 3) image, got shape %s" % (pixels.shape,)
        )

    mean = [0.485 * 255, 0.456 * 255, 0.406 * 255]
    std = [0.229 * 255, 0.224 * 255, 0.225 * 255]
    normalised = (pixels - mean) / std
    chw = normalised.transpose((2, 0, 1))  # (H, W, C) -> (C, H, W)
    return Tensor(np.array([chw], np.float32))  # leading batch dimension


def infer_one(network, image_path):
    """Classify a single image file and return the predicted label index.

    Args:
        network: Callable network taking the tensor built by ``image_process``.
        image_path: Path of the image file to classify.

    Returns:
        Index of the largest logit for the image.
    """
    # The context manager closes the file handle once the pixels are loaded.
    with Image.open(image_path) as img:
        # Force three channels so grayscale / RGBA files normalise correctly,
        # and pass the size as (width, height), which is the order PIL expects.
        image = img.convert("RGB").resize((config.image_width, config.image_height))
    logits = network(image_process(image))
    return np.argmax(logits.asnumpy(), axis=1)[0]


def infer_folder(network, folder_path, label_map):
    """Classify every image file under ``folder_path/<class_name>/``.

    Args:
        network: Network forwarded to ``infer_one``.
        folder_path: Directory holding one sub-directory per class.
        label_map: Mapping of class sub-directory name to its integer label.

    Returns:
        ``(true_labels, pred_labels)``: two equally long lists. Class
        directories that do not exist are skipped, and files are visited in
        sorted name order.
    """
    true_labels = []
    pred_labels = []

    for class_name, label in label_map.items():
        class_folder = os.path.join(folder_path, class_name)
        if not os.path.isdir(class_folder):
            continue

        for image_name in sorted(os.listdir(class_folder)):
            image_path = os.path.join(class_folder, image_name)
            if not os.path.isfile(image_path):
                # Sub-directories (e.g. .ipynb_checkpoints) are not images.
                continue
            prediction = infer_one(network, image_path)
            true_labels.append(label)
            pred_labels.append(prediction)

    return true_labels, pred_labels
def plot_metrics(true_labels, pred_labels, label_map):
    """Plot a confusion matrix and per-class ROC curves side by side.

    Args:
        true_labels: Sequence of ground-truth integer labels.
        pred_labels: Sequence of predicted integer labels, same length.
        label_map: Mapping of class name to integer label; it fixes the row and
            column order of the matrix and the tick labels.

    The ROC curves are computed from hard predictions (1.0 for the predicted
    class, 0.0 otherwise) rather than from scores, so each curve has a single
    operating point. A class that covers none or all of the true labels is
    left out of the ROC plot because its curve is undefined.
    """
    class_names = list(label_map.keys())
    class_ids = list(label_map.values())
    y_true = np.asarray(true_labels)
    y_pred = np.asarray(pred_labels)

    # Confusion matrix. Passing ``labels`` keeps one row/column per class in
    # label_map order even when a class never occurs in the data.
    cm = confusion_matrix(y_true, y_pred, labels=class_ids)
    print("Confusion Matrix:\n", cm)

    plt.figure(figsize=(10, 5))
    plt.subplot(1, 2, 1)
    plt.imshow(cm, interpolation='nearest', cmap=plt.cm.Blues)
    plt.title("Confusion Matrix")
    plt.colorbar()
    tick_positions = np.arange(len(class_names))
    plt.xticks(ticks=tick_positions, labels=class_names, rotation=45)
    plt.yticks(ticks=tick_positions, labels=class_names)
    plt.ylabel("True Label")
    plt.xlabel("Predicted Label")

    # ROC curves, one-vs-rest per class. Comparing against the class id works
    # for any integer ids, not only 0..n-1.
    plt.subplot(1, 2, 2)
    curves_drawn = 0
    for name, class_id in zip(class_names, class_ids):
        is_class = y_true == class_id
        if not is_class.any() or is_class.all():
            continue
        predicted_as_class = (y_pred == class_id).astype(float)
        fpr, tpr, _ = roc_curve(is_class, predicted_as_class)
        roc_auc = auc(fpr, tpr)
        plt.plot(fpr, tpr, label=f"Class {name} (AUC = {roc_auc:.2f})")
        curves_drawn += 1

    plt.plot([0, 1], [0, 1], "k--")
    plt.title("ROC Curve")
    plt.xlabel("False Positive Rate")
    plt.ylabel("True Positive Rate")
    if curves_drawn:
        plt.legend(loc="lower right")
    plt.tight_layout()
    plt.show()
def _load_inference_network():
    """Build MobileNetV2 and load the checkpoint ``CKPT`` from ``config.save_ckpt_path``."""
    backbone = MobileNetV2Backbone(last_channel=config.backbone_out_channels)
    head = MobileNetV2Head(input_channel=backbone.out_channels, num_classes=config.num_classes)
    network = mobilenet_v2(backbone, head)
    load_checkpoint(os.path.join(config.save_ckpt_path, CKPT), network)
    return network


def infer_sleep(folder_path, network=None):
    """Estimate sleep time from the images of one recording.

    Args:
        folder_path: Directory with the class sub-directories ``NS`` and ``S``.
        network: Already loaded network. When omitted, one is built and the
            checkpoint named by the global ``CKPT`` is loaded.

    Returns:
        Sum of the predicted labels divided by 2, reported as minutes (each
        image predicted as label 1 counts as half a minute).
    """
    if network is None:
        network = _load_inference_network()

    # Class sub-directory name -> label; adjust to the dataset.
    label_map = {"NS": 0, "S": 1}

    _, pred_labels = infer_folder(network, folder_path, label_map)
    # NOTE(review): summing counts the "S" predictions only while the model
    # outputs labels 0/1 - confirm config.num_classes is 2.
    S_num = sum(pred_labels)
    print("Sleep_time: %.2f min" %(S_num/2))
    return S_num/2


def infer_sas(folder_path, network=None):
    """Count the images of one recording that are predicted as apnea.

    Args:
        folder_path: Directory with the class sub-directories ``N`` and ``A``.
        network: Already loaded network. When omitted, one is built and the
            checkpoint named by the global ``CKPT`` is loaded.

    Returns:
        Sum of the predicted labels (the number of label-1 predictions).
    """
    if network is None:
        network = _load_inference_network()

    # Class sub-directory name -> label; adjust to the dataset.
    label_map = {"N": 0, "A": 1}

    _, pred_labels = infer_folder(network, folder_path, label_map)
    A_num = sum(pred_labels)
    print("A_num: %d" %(A_num))
    return A_num


# Run inference: sleep-stage model over every recording folder.
mother_name = './0-AP'
folder_path = mother_name + "/For_test_s"
config.save_ckpt_path = "./ckpt2"
CKPT = "save_mobilenetV2_SLEEP_model.ckpt"

# Load the checkpoint once and reuse it for every recording.
sleep_network = _load_inference_network()

# Entries of folder_path, in sorted order.
fileNames = sorted(os.listdir(folder_path))
for file in fileNames:
    # Full path of this recording's folder.
    newDir = os.path.join(folder_path, file)
    if not os.path.isdir(newDir):
        continue
    time_60s = infer_sleep(newDir, network=sleep_network)
    print("%s 睡眠时长= %.2f min" %(file, time_60s))


# Disabled apnea-event / AHI step, kept for reference. `file_name` and
# `time_30s` are not defined in this cell.
# folder_path = mother_name + "./For_test/" +file_name
# config.save_ckpt_path = "./ckpt"
# CKPT = "save_mobilenetV2_model.ckpt"
# sas_num = infer_sas(folder_path)

# print("AHI = %.2f" %(sas_num*120/time_30s))
+ "language": "python", + "name": "conda-env-.conda-mindspore_py37-py" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.7.5" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +}