diff --git a/tutorials/source_zh_cn/advanced_use/mobilenetv2_finetune.md b/tutorials/source_zh_cn/advanced_use/mobilenetv2_finetune.md new file mode 100644 index 0000000000000000000000000000000000000000..7a7f8ac0e6ec56d65663d68a96578fb1943ac831 --- /dev/null +++ b/tutorials/source_zh_cn/advanced_use/mobilenetv2_finetune.md @@ -0,0 +1,282 @@ +#
MobileNetV2 Incremental Learning
+
+`Ascend` `GPU` `Model Development` `Intermediate` `Expert`
+
+<!-- TOC -->
+
+- [MobileNetV2 Incremental Learning](#mobilenetv2-incremental-learning)
+    - [Overview](#overview)
+    - [Task Description and Preparation](#task-description-and-preparation)
+        - [Configuring the Environment](#configuring-the-environment)
+        - [Cloning the Code](#cloning-the-code)
+        - [Preparing the Pretrained Model](#preparing-the-pretrained-model)
+        - [Preparing the Dataset](#preparing-the-dataset)
+    - [Pretrained Model Loading Code in Detail](#pretrained-model-loading-code-in-detail)
+    - [Training Parameters](#training-parameters)
+    - [Loading and Training](#loading-and-training)
+        - [Starting Incremental Training](#starting-incremental-training)
+        - [Adjusting the Number of Nodes](#adjusting-the-number-of-nodes)
+        - [Training Results](#training-results)
+        - [Validating the Incrementally Trained Model](#validating-the-incrementally-trained-model)
+
+<!-- /TOC -->
+
+## Overview
+
+In computer vision tasks, training a network from scratch takes a long time and a large amount of compute. Pretrained models are usually trained on large public datasets such as OpenImage, ImageNet, VOC, and COCO, which contain hundreds of thousands or even more than a million images. The data available for most tasks is far smaller, and training a network on it from scratch not only consumes much time and compute but also leaves the model prone to local minima and overfitting. Most tasks therefore start from a pretrained model and perform incremental learning on top of it.
+
+MindSpore is a versatile machine learning framework: it can run on device-side hardware such as phones and PCs as well as on server clusters in the cloud. The following tutorial uses MobileNetV2 as an example to show how to do incremental learning with MindSpore.
+
+## Task Description and Preparation
+
+### Configuring the Environment
+
+When running in the Huawei Cloud environment, MindSpore and the Ascend AI Processor are already set up, so this section can be skipped. When running locally, install the MindSpore framework and configure an Ascend AI Processor or a GPU.
+
+1. Install the MindSpore framework.
+    On EulerOS, Ubuntu, Windows, or another system, [install the MindSpore version](https://www.mindspore.cn/install) that matches your system and processor architecture.
+
+2. Configure the Ascend environment.
+    Take the Ascend 910 AI Processor as an example. The JSON configuration file `hccl_config.json` for an environment with one server and eight processors is shown below. A single-processor or multi-processor environment can be configured by adjusting `"server_count"` and `"device"` in this example:
+
+    ```json
+    {
+        "version": "1.0",
+        "server_count": "1",
+        "server_list": [
+            {
+                "server_id": "10.155.111.140",
+                "device": [
+                    {"device_id": "0","device_ip": "192.1.27.6","rank_id": "0"},
+                    {"device_id": "1","device_ip": "192.2.27.6","rank_id": "1"},
+                    {"device_id": "2","device_ip": "192.3.27.6","rank_id": "2"},
+                    {"device_id": "3","device_ip": "192.4.27.6","rank_id": "3"},
+                    {"device_id": "4","device_ip": "192.1.27.7","rank_id": "4"},
+                    {"device_id": "5","device_ip": "192.2.27.7","rank_id": "5"},
+                    {"device_id": "6","device_ip": "192.3.27.7","rank_id": "6"},
+                    {"device_id": "7","device_ip": "192.4.27.7","rank_id": "7"}],
+                "host_nic_ip": "reserve"
+            }
+        ],
+        "status": "completed"
+    }
+    ```
+
+    When using Ascend AI Processors, set the context in the code as follows before training or testing starts:
+
+    ```python
+    import os
+    from mindspore import context
+    from mindspore.communication.management import init
+
+    if __name__ == "__main__":
+        context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=int(os.environ["DEVICE_ID"]))
+        init()
+        ...
+    ```
+
+3. Configure the GPU environment.
+    When using GPUs, set the context in the code as follows before training or testing starts:
+
+    ```python
+    from mindspore import context
+    from mindspore.communication.management import init
+
+    if __name__ == "__main__":
+        context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
+        init("nccl")
+        ...
+    ```
+
+### Cloning the Code
+
+Clone the [MindSpore open-source repository](https://gitee.com/mindspore/mindspore.git) from Gitee and enter `./model_zoo/official/cv/mobilenetv2/`:
+
+```bash
+git clone git@gitee.com:mindspore/mindspore.git
+cd ./mindspore/model_zoo/official/cv/mobilenetv2
+```
+
+For training, run `launch.py`; it starts one `train.py` process for each allocated Ascend AI Processor or GPU and binds each process to one device. Do not start training by running `train.py` directly. For validation and testing, run `eval.py` directly.
+
+Code structure:
+
+    └─mobilenetv2
+        ├─checkpoint
+        ├─src
+        ├─launch.py
+        ├─train.py
+        ├─eval.py
+        └─README.md
+
+### Preparing the Pretrained Model
+
+Download the pretrained model into the following directory:
+`./mindspore/model_zoo/official/cv/mobilenetv2/pretrain_checkpoint/[pretrain_checkpoint_file_name]`
+
+```bash
+cp -v [pretrain_checkpoint_file_path] ./mindspore/model_zoo/official/cv/mobilenetv2/pretrain_checkpoint/
+```
+
+### Preparing the Dataset
+
+Prepare a dataset managed in ImageFolder format and add the `--dataset_path [dataset_path]` parameter when running `launch.py`:
+
+    Dataset structure:
+
+    └─ImageFolder
+        ├─train
+        │   ├─class1Folder
+        │   ├─class2Folder
+        │   └─......
+        └─eval
+            ├─class1Folder
+            ├─class2Folder
+            └─......
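+
+The `dataset.py` script under `src` reads such a folder with MindSpore's `ImageFolderDatasetV2`, which maps each class subfolder to one label. The following is a minimal sketch of that mapping; the path `/data/ImageFolder` is a placeholder:
+
+```python
+import mindspore.dataset.engine as de
+
+# every subfolder of `train` becomes one class; the files inside it get that label
+ds = de.ImageFolderDatasetV2("/data/ImageFolder/train", num_parallel_workers=8, shuffle=True)
+print(ds.get_dataset_size())  # number of images found under the class folders
+```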
+
+## Pretrained Model Loading Code in Detail
+
+Incremental learning starts by loading a pretrained model. Across datasets and tasks, the feature extraction layers (convolutions) are largely transferable, but the way feature vectors are combined (the fully connected layer) differs, and the number of classes (the fully connected layer's output_size) usually differs as well. Therefore, only the feature extraction parameters are loaded; the fully connected layer parameters are discarded.
+
+When a network is built, MindSpore automatically and implicitly names the parameters in the graph after the Python variable names. Take the MobileNetV2 head initialization code below as an example: global average pooling, a fully connected layer, and a Dropout layer together form the head. Global average pooling and Dropout carry no parameters, so within the head only the fully connected layer does. Its `param.name` prefix is "head.1", and its weight and bias are "head.1.weight" and "head.1.bias" respectively.
+
+```python
+head = ([GlobalAvgPooling(), nn.Dense(self.out_channels, num_classes, has_bias=True)]
+        if not has_dropout else
+        [GlobalAvgPooling(), nn.Dropout(0.2), nn.Dense(self.out_channels, num_classes, has_bias=True)])
+self.head = nn.SequentialCell(head)
+self._initialize_weights()
+```
+
+For a custom network, add the following code after the net is initialized to print every parameter name; this makes it easy to find the fully connected layer's names, whose weights can then be discarded with `param_dict.pop(param.name)`:
+
+```python
+for param in net.get_parameters():
+    print(param.name)
+```
+
+Before loading the checkpoint, build the network from the Network definition and initialize the network weights. In the MobileNetV2 incremental-training loading code below, line 6 loads pretrain_checkpoint_file into `param_dict`; lines 8-11 pop the fully connected layer's weights out of `param_dict`; line 14 loads the pretrained feature extraction weights (convolution, BN, and so on) into the net; lines 15-17 freeze the gradients of every parameter outside the fully connected layer, so the subsequent training no longer updates the feature extraction weights and trains only the fully connected layer.
+
+```python
+ 1: from mindspore.train.serialization import load_checkpoint, load_param_into_net
+ 2:
+ 3: def model_fine_tune(net, pretrained_path):
+ 4:     if pretrained_path is None:
+ 5:         return
+ 6:     param_dict = load_checkpoint(pretrained_path)
+ 7:     # use pop to discard the pretrained weight, bias and optimizer moments of the fully connected layer
+ 8:     param_dict.pop('head.1.weight')
+ 9:     param_dict.pop('head.1.bias')
+10:     param_dict.pop('moments.head.1.weight')
+11:     param_dict.pop('moments.head.1.bias')
+12:
+13:     # freeze every parameter except the fully connected layer's during fine-tuning
+14:     load_param_into_net(net, param_dict)
+15:     for param in net.get_parameters():
+16:         if param.name not in ("head.1.weight", "head.1.bias"):
+17:             param.requires_grad = False
+```
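+
+The snippet below is a usage sketch, not an excerpt from `train.py`: the class count `26` and the checkpoint name are placeholders. It shows how the function above combines with an optimizer that is built only over the parameters left trainable by the freeze:
+
+```python
+from mindspore import nn
+from src.mobilenetV2 import mobilenet_v2
+
+# build the network with a new head, then load and freeze the pretrained backbone
+net = mobilenet_v2(device_target="Ascend", num_classes=26)
+model_fine_tune(net, "./pretrain_checkpoint/mobilenetV2-200_625.ckpt")
+
+# trainable_params() now returns only head.1.weight and head.1.bias
+opt = nn.Momentum(net.trainable_params(), learning_rate=0.01, momentum=0.9)
+```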
+
+## Training Parameters
+
+Training parameters are read when `launch.py` (training) and `eval.py` (validation) are run:
+
+```bash
+python launch.py --dataset_path [dataset_path] --batch_size 128 --epoch_size 200
+```
+
+- `--platform`: target device, "Ascend" by default; can be set to "GPU".
+- `--nproc_per_node`: number of processes per node (one server/PC counts as one node); recommended to be set to the number of Ascend AI Processors or GPUs on the machine.
+- `--visible_devices`: device IDs as a comma-separated string; training binds each process to the device with the corresponding ID, so the number of IDs should equal the number of processes.
+- `--hccl_config_path`: path of the Ascend JSON configuration file, required when platform is "Ascend".
+- `--training_script`: training script, `./train.py` by default; usually no new value needs to be passed.
+- `--dataset_path`: path of the training/validation dataset; no default value, must be provided.
+- `--pretrain_checkpoint_path`: path of the checkpoint file whose pretrained parameters are loaded for incremental training.
+- `--num_classes`: number of output classes, integer, 1000 by default.
+
+## Loading and Training
+
+### Starting Incremental Training
+
+When running `launch.py` for MobileNetV2, pass `dataset_path` and `hccl_config_path`. For incremental training, also pass `pretrain_checkpoint_path`; if it is omitted, training starts from a freshly initialized network. Other training parameters such as `platform`, `nproc_per_node`, `batch_size`, `num_classes`, `epoch_size`, `image_height`, and `image_width` can be passed when running `launch.py` or left at their defaults.
+
+Append `&> [log_file_path]` to the command to redirect both standard output and error output to a log file. If incremental training starts successfully, each epoch's training time and loss are continuously appended to `./train/device*/log*.log`; otherwise, the error message is written to the same file.
+
+```shell
+python launch.py --platform="Ascend" --dataset_path='/data/ImageFolder/train' --nproc_per_node=8 \
+--pretrain_checkpoint_path="./checkpoint/mobilenetV2-200_625.ckpt" \
+--visible_devices='0,1,2,3,4,5,6,7' --hccl_config_path=[hccl_config_path] &> ./train.log
+```
+
+### Adjusting the Number of Nodes
+
+When running the Python file, set `nproc_per_node` to the number of Ascend AI Processors/GPUs per node, and `visible_devices` to the IDs of the usable processors, i.e. the Ascend AI Processor/GPU IDs; one or more of the IDs 0-7 can be chosen. Currently, the number of processes on an Ascend node can only be set to 1 or 8.
+
+- Example 1: use eight Ascend AI Processors with device IDs "0,1,2,3,4,5,6,7":
+
+    ```shell
+    python launch.py --dataset_path=[dataset_path] --platform="Ascend" --nproc_per_node=8 --visible_devices="0,1,2,3,4,5,6,7" --hccl_config_path=[hccl_config_path] --pretrain_checkpoint_path="./pretrain_checkpoint/mobilenetV2-200_625.ckpt"
+    ```
+
+- Example 2: use one Ascend AI Processor with device ID "0":
+
+    ```shell
+    python launch.py --dataset_path=[dataset_path] --platform="Ascend" --nproc_per_node=1 --visible_devices="0" --hccl_config_path=[hccl_config_path] --pretrain_checkpoint_path="./checkpoint/mobilenetV2-200_625.ckpt"
+    ```
+
+- Example 3: use one Ascend AI Processor with device ID "4":
+
+    ```shell
+    python launch.py --dataset_path=[dataset_path] --platform="Ascend" --nproc_per_node=1 --visible_devices="4" --hccl_config_path=[hccl_config_path] --pretrain_checkpoint_path="./checkpoint/mobilenetV2-200_625.ckpt"
+    ```
+
+### Training Results
+
+1. View the training log:
+
+    ```bash
+    cat ./train/device0/log0.log
+    ```
+
+    The output is as follows:
+
+    ```bash
+    train args: Namespace(batch_size=128, checkpoint_path=None, dataset_path='/dataset/ImageFolder/train/', epoch_size=200, image_height=224, image_width=224, keep_checkpoint_max=200, label_smooth=0.1, loss_scale=1024, lr=0.4, momentum=0.9, num_classes=1000, platform='Ascend', save_checkpoint=True, save_checkpoint_epochs=1, save_checkpoint_path='./checkpoint', warmup_epochs=4, weight_decay=4e-05)
+    cfg: {'ccl': 'hccl', 'platform': 'Ascend', 'device_id': 0, 'rank_id': 0, 'rank_size': 1, 'run_distribute': False}
+    parallel args: rank_id 0, device_id 0, rank_size 1
+    =====WWW=====
+    Tensor shape:[[const vector][1]]Float32
+    value:[ 0.0000000000e+00]
+    =====WWW=====
+    Tensor shape:[[const vector][1]]Float32
+    value:[ 3.99680248e-05]
+    ...
+    ```
+
+2. View the saved checkpoint files:
+
+    ```bash
+    ls ./train/device0/checkpoint/*.ckpt
+    ```
+
+    The output is as follows:
+
+    ```bash
+    mobilenetV2-100_625.ckpt  mobilenetV2-191_625.ckpt
+    mobilenetV2-101_625.ckpt  mobilenetV2-192_625.ckpt
+    ...
+    ```
+
+### Validating the Incrementally Trained Model
+
+To test the model on the validation set, `dataset_path` and `checkpoint_path` must be provided; `platform` defaults to "Ascend" and can be set to "GPU". Standard output and error output are written to the `infer.log` file:
+
+```bash
+python eval.py --platform="Ascend" --dataset_path="/data/ImageFolder/eval" \
+--checkpoint_path="./checkpoint/mobilenetV2-200_625.ckpt" &> ./infer.log
+```
diff --git a/tutorials/source_zh_cn/mobilenetv2_finetune.md b/tutorials/source_zh_cn/mobilenetv2_finetune.md
new file mode 100644
index 0000000000000000000000000000000000000000..851d53faf00b6891cd630f7f2e0bdee5f75b3499
--- /dev/null
+++ b/tutorials/source_zh_cn/mobilenetv2_finetune.md
@@ -0,0 +1,256 @@
+# MobileNetV2 Incremental Learning
+
+## Overview
+
+In computer vision tasks, training a network from scratch takes a long time and a large amount of compute. Pretrained models are usually trained on large public datasets such as OpenImage, ImageNet, VOC, and COCO, which typically contain hundreds of thousands or even more than a million images. The data available for most tasks is much smaller, and training a network on it from scratch consumes much time and compute while making the model prone to local minima and overfitting. Most tasks therefore start from a pretrained model and perform incremental learning on top of it.
+
+MindSpore is a versatile machine learning framework: it can run on device-side hardware such as phones and PCs as well as on server clusters in the cloud. The following tutorial uses MobileNetV2 as an example to show how to do incremental learning with MindSpore.
+
+## Task Description and Preparation
+
+### Configuring the Environment
+
+When running in the Huawei Cloud environment, the MindSpore framework and the Ascend chips are already set up, so this section can be skipped. When running locally, install MindSpore and configure an Ascend chip or a GPU.
+
+1. Install the MindSpore framework.
+    On EulerOS, Ubuntu, or Windows, [install the MindSpore framework](https://www.mindspore.cn/install); pick the version that matches your device.
+
+2. Configure the Ascend environment.
+    Take the Ascend 910 AI Processor as an example. A JSON configuration file for an environment with one server and eight devices is shown below; this sample names the file rank_table_8pcs.json. For a two-device environment, refer to the rank_table_2pcs.json file in the sample code.
+
+```json
+{
+    "version": "1.0",
+    "server_count": "1",
+    "server_list": [
+        {
+            "server_id": "10.155.111.140",
+            "device": [
+                {"device_id": "0","device_ip": "192.1.27.6","rank_id": "0"},
+                {"device_id": "1","device_ip": "192.2.27.6","rank_id": "1"},
+                {"device_id": "2","device_ip": "192.3.27.6","rank_id": "2"},
+                {"device_id": "3","device_ip": "192.4.27.6","rank_id": "3"},
+                {"device_id": "4","device_ip": "192.1.27.7","rank_id": "4"},
+                {"device_id": "5","device_ip": "192.2.27.7","rank_id": "5"},
+                {"device_id": "6","device_ip": "192.3.27.7","rank_id": "6"},
+                {"device_id": "7","device_ip": "192.4.27.7","rank_id": "7"}],
+            "host_nic_ip": "reserve"
+        }
+    ],
+    "status": "completed"
+}
+```
+
+When using Ascend chips, set the context in the code as follows before training or testing starts:
+
+```python
+import os
+from mindspore import context
+from mindspore.communication.management import init
+
+if __name__ == "__main__":
+    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=int(os.environ["DEVICE_ID"]))
+    init()
+    ...
+```
+
+3. Configure the GPU environment.
+    When using GPUs, set the context in the code as follows before training or testing starts:
+
+```python
+from mindspore import context
+from mindspore.communication.management import init
+
+if __name__ == "__main__":
+    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
+    init("nccl")
+    ...
+```
+
+### Cloning the Code
+
+Download the [MindSpore open-source repository](https://gitee.com/mindspore/mindspore.git) from Gitee and enter `./model_zoo/official/cv/mobilenetv2/`, or run the following two commands to clone the MindSpore code with git and enter the mobilenetv2 folder:
+
+```bash
+git clone git@gitee.com:mindspore/mindspore.git
+cd ./mindspore/model_zoo/official/cv/mobilenetv2
+```
+
+For training, run `launch.py`; depending on the number of nodes, it allocates the devices and starts multiple `train.py` processes at the same time. Do not start training by running `train.py` directly. For validation and testing, run `eval.py`.
+
+1. Code structure:
+
+        └─mobilenetv2
+            ├─checkpoint
+            ├─src
+            ├─launch.py
+            ├─train.py
+            ├─eval.py
+            └─README.md
+
+### Preparing the Pretrained Model
+
+Copy the pretrained model into the directory `./mindspore/model_zoo/official/cv/mobilenetv2/pretrain_checkpoint/[pretrain_checkpoint_file_name]`:
+
+```bash
+cp -v [pretrain_checkpoint_file_path] ./mindspore/model_zoo/official/cv/mobilenetv2/pretrain_checkpoint/
+```
+
+### Preparing the Dataset
+
+Prepare a dataset managed in ImageFolder format and add the `dataset_path` parameter when running from the bash command line:
+
+    Dataset structure:
+
+    └─ImageFolder
+        ├─train
+        │   ├─class1Folder
+        │   ├─class2Folder
+        │   └─......
+        └─eval
+            ├─class1Folder
+            ├─class2Folder
+            └─......
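+
+For reference, `src/dataset.py` in this directory applies the following preprocessing to this folder (condensed excerpt; the training branch decodes and crops in a single step, while the eval branch decodes, resizes, and center-crops):
+
+```python
+import mindspore.dataset.transforms.vision.c_transforms as C
+
+resize_crop_op = C.RandomCropDecodeResize(224, scale=(0.08, 1.0), ratio=(0.75, 1.333))
+horizontal_flip_op = C.RandomHorizontalFlip(prob=0.5)
+color_adjust_op = C.RandomColorAdjust(brightness=0.4, contrast=0.4, saturation=0.4)
+normalize_op = C.Normalize(mean=[0.485 * 255, 0.456 * 255, 0.406 * 255],
+                           std=[0.229 * 255, 0.224 * 255, 0.225 * 255])
+change_swap_op = C.HWC2CHW()
+
+# training pipeline; the eval pipeline uses Decode + Resize + CenterCrop instead
+train_trans = [resize_crop_op, horizontal_flip_op, color_adjust_op, normalize_op, change_swap_op]
+```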
+
+## Pretrained Model Loading Code in Detail
+
+Incremental learning starts by loading the pretrained parameter file. Across datasets and tasks, the feature extraction layers (convolutions) are largely transferable, but the way feature vectors are combined (the fully connected layer) differs, and the number of classes (the fully connected layer's output_size) may differ as well.
+To load the model, first build the network from the Network definition and initialize all of its parameters (the pretrained model's fully connected class_num is 1000, and class_num usually changes for incremental training); then load the pretrained parameters of the **feature extraction layers whose structure is unchanged, such as convolution and BN**. During the subsequent incremental training, the feature extraction parameters are frozen: gradients no longer flow back to update them, and only the fully connected parameters are updated.
+
+```python
+from mindspore.train.serialization import load_checkpoint, load_param_into_net
+
+def model_fine_tune(net, pretrained_path):
+    if pretrained_path is None:
+        return
+    param_dict = load_checkpoint(pretrained_path)
+    # use pop to discard the pretrained weight, bias and optimizer moments of the fully connected layer
+    param_dict.pop('head.1.weight')
+    param_dict.pop('head.1.bias')
+    param_dict.pop('moments.head.1.weight')
+    param_dict.pop('moments.head.1.bias')
+
+    # freeze every parameter except the fully connected layer's during fine-tuning
+    load_param_into_net(net, param_dict)
+    for param in net.get_parameters():
+        if param.name not in ("head.1.weight", "head.1.bias"):
+            param.requires_grad = False
+```
+
+When a network is built, MindSpore automatically and implicitly names the parameters in the graph after the Python variable names. Taking MobileNetV2 as an example, global average pooling, a fully connected layer, and a Dropout layer together form the head, so the fully connected layer's param.name is "head.1", and its weight and bias are "head.1.weight" and "head.1.bias".
+
+```python
+head = ([GlobalAvgPooling(), nn.Dense(self.out_channels, num_classes, has_bias=True)]
+        if not has_dropout else
+        [GlobalAvgPooling(), nn.Dropout(0.2), nn.Dense(self.out_channels, num_classes, has_bias=True)])
+self.head = nn.SequentialCell(head)
+self._initialize_weights()
+```
+
+For a custom network, the following code prints every parameter name, which makes it easy to find the fully connected layer's names and discard them with `param_dict.pop(param.name)`:
+
+```python
+for param in net.get_parameters():
+    print(param.name)
+```
+
+## Training Parameters
+
+### Reading the Parameters
+
+During training and validation, `launch.py` and `eval.py` call the `*_parse_args()` functions in `./mobilenetv2/src/args.py` to read the training parameters from the shell command line:
+
+```bash
+python launch.py --dataset_path [dataset_path] --batch_size 128 --epoch_size 200
+```
+
+- `--platform`: target device, "Ascend" by default; can be set to "CPU" or "GPU".
+- `--nproc_per_node`: number of processes per node (one server/PC counts as one node); recommended to be set to the number of Ascend chips, CPUs, or GPUs on the machine.
+- `--visible_devices`: device IDs as a comma-separated string; training binds each process to the device with the corresponding ID, so the number of IDs should equal the number of processes.
+- `--hccl_config_path`: path of the Ascend JSON configuration file, required when platform is "Ascend".
+- `--training_script`: training script, "./train.py" by default; usually no new value needs to be passed.
+- `--dataset_path`: path of the training/validation dataset; no default value, must be provided.
+- `--pretrain_checkpoint_path`: path of the checkpoint file whose pretrained parameters are loaded for incremental training.
+- `--num_classes`: number of output classes, integer, 1000 by default.
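+
+`launch_parse_args()` in `src/args.py` keeps only the launcher's own flags and forwards everything it does not recognize to the training script, which is how the parameters above reach `train.py`. A condensed excerpt of that mechanism:
+
+```python
+import argparse
+
+launch_parser = argparse.ArgumentParser()
+launch_parser.add_argument("--nproc_per_node", type=int, default=1)
+launch_parser.add_argument("--visible_devices", type=str, default="0,1,2,3,4,5,6,7")
+
+# flags the launcher does not define (e.g. --dataset_path) end up in `unknown`
+launch_args, unknown = launch_parser.parse_known_args()
+launch_args.training_script_args = unknown  # later appended to the train.py command line
+```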
+
+## Loading and Training
+
+### Starting Incremental Training
+
+When running `launch.py` for MobileNetV2, pass `dataset_path` and `hccl_config_path`. For incremental training, also pass `pretrain_checkpoint_path`; if it is omitted, training starts from a freshly initialized network. Other training parameters such as `platform`, `nproc_per_node`, `batch_size`, `num_classes`, `epoch_size`, `image_height`, and `image_width` can be passed when running `launch.py` or left at their defaults.
+
+When running `launch.py`, append `&> [log_file_path]` to the command to redirect both standard output and error output to a log file. If incremental training starts successfully, each epoch's training time and loss are continuously appended to `./train/device*/log*.log`; otherwise, the error message is written to the same file.
+
+```shell
+python launch.py --platform="Ascend" --dataset_path='/data/ImageFolder/train' --nproc_per_node=8 \
+--pretrain_checkpoint_path="./checkpoint/mobilenetV2-200_625.ckpt" \
+--visible_devices='0,1,2,3,4,5,6,7' --hccl_config_path="*.json" &> ./train.log
+```
+
+### Adjusting the Number of Nodes
+
+When running the Python file, set `nproc_per_node` to the number of processes per node and `visible_devices` to the usable device IDs, i.e. the GPU/Ascend device IDs, usually one or more of the IDs 0-7. Currently, the CPU needs no process count or device IDs, and the number of processes on an Ascend node can only be set to 1 or 8.
+
+- Example 1:
+
+    ```shell
+    python launch.py --dataset_path=[dataset_path] --platform="Ascend" --nproc_per_node=8 --visible_devices="0,1,2,3,4,5,6,7" --hccl_config_path=[hccl_config_path] --pretrain_checkpoint_path="./pretrain_checkpoint/mobilenetV2-200_625.ckpt"
+    ```
+
+- Example 2:
+
+    ```shell
+    python launch.py --dataset_path=[dataset_path] --platform="Ascend" --nproc_per_node=1 --visible_devices="0" --hccl_config_path=[hccl_config_path] --pretrain_checkpoint_path="./checkpoint/mobilenetV2-200_625.ckpt"
+    ```
+
+- Example 3:
+
+    ```shell
+    python launch.py --dataset_path=[dataset_path] --platform="Ascend" --nproc_per_node=1 --visible_devices="4" --hccl_config_path=[hccl_config_path] --pretrain_checkpoint_path="./checkpoint/mobilenetV2-200_625.ckpt"
+    ```
+
+### Training Results
+
+1. View the training log:
+
+    ```bash
+    cat ./train/device0/log0.log
+    ```
+
+    The output is as follows:
+
+    ```bash
+    train args: Namespace(batch_size=128, checkpoint_path=None, dataset_path='/store/dataset/ImageNet_Original/train/', epoch_size=200, image_height=224, image_width=224, keep_checkpoint_max=200, label_smooth=0.1, loss_scale=1024, lr=0.4, momentum=0.9, num_classes=1000, platform='Ascend', save_checkpoint=True, save_checkpoint_epochs=1, save_checkpoint_path='./checkpoint', warmup_epochs=4, weight_decay=4e-05)
+    cfg: {'ccl': 'hccl', 'platform': 'Ascend', 'device_id': 0, 'rank_id': 0, 'rank_size': 1, 'run_distribute': False}
+    parallel args: rank_id 0, device_id 0, rank_size 1
+    =====WWW=====
+    Tensor shape:[[const vector][1]]Float32
+    value:[ 0.0000000000e+00]
+    =====WWW=====
+    Tensor shape:[[const vector][1]]Float32
+    value:[ 3.99680248e-05]
+    ...
+    ```
+
+2. View the saved checkpoint files:
+
+    ```bash
+    ls ./train/device0/checkpoint/*.ckpt
+    ```
+
+    The output is as follows:
+
+    ```bash
+    mobilenetV2-100_625.ckpt  mobilenetV2-191_625.ckpt
+    mobilenetV2-101_625.ckpt  mobilenetV2-192_625.ckpt
+    ...
+    ```
+
+### Validating the Model After Training
+
+To test the model on the validation set, `dataset_path` and `checkpoint_path` must be provided; `platform` defaults to "Ascend" and can be set to "GPU" or "CPU". Standard output and error output are written to the `infer.log` file:
+
+```bash
+python eval.py --platform="Ascend" --dataset_path='/data/ImageFolder/eval' \
+--checkpoint_path="./pretrain_checkpoint/mobilenetV2-200_625.ckpt" &> ./infer.log
+```
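+
+After validation, the fine-tuned checkpoint can also be used directly for inference. The following is a minimal sketch; the random tensor stands in for a real preprocessed image, and the class count and checkpoint name are placeholders:
+
+```python
+import numpy as np
+from mindspore import Tensor
+from mindspore.train.serialization import load_checkpoint, load_param_into_net
+from src.mobilenetV2 import mobilenet_v2
+
+net = mobilenet_v2(device_target="Ascend", num_classes=1000)
+load_param_into_net(net, load_checkpoint("./checkpoint/mobilenetV2-200_625.ckpt"))
+net.set_train(False)
+
+# a preprocessed NCHW image batch would go here
+image = Tensor(np.random.uniform(size=(1, 3, 224, 224)).astype(np.float32))
+logits = net(image)
+print(logits.asnumpy().argmax())  # predicted class index
+```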
diff --git a/tutorials/tutorial_code/mobilenetv2_finetune/Readme.md b/tutorials/tutorial_code/mobilenetv2_finetune/Readme.md
new file mode 100644
index 0000000000000000000000000000000000000000..f7021170fdfe790a552b3e08668307d02daa6dd8
--- /dev/null
+++ b/tutorials/tutorial_code/mobilenetv2_finetune/Readme.md
@@ -0,0 +1,196 @@
+# Contents
+
+- [MobileNetV2 Description](#mobilenetv2-description)
+- [Model Architecture](#model-architecture)
+- [Dataset](#dataset)
+- [Features](#features)
+    - [Mixed Precision](#mixed-precision)
+- [Environment Requirements](#environment-requirements)
+- [Script Description](#script-description)
+    - [Script and Sample Code](#script-and-sample-code)
+    - [Training Process](#training-process)
+    - [Eval Process](#eval-process)
+- [Model Description](#model-description)
+    - [Performance](#performance)
+        - [Training Performance](#training-performance)
+        - [Inference Performance](#inference-performance)
+- [Description of Random Situation](#description-of-random-situation)
+- [ModelZoo Homepage](#modelzoo-homepage)
+
+# [MobileNetV2 Description](#contents)
+
+MobileNetV2 is a convolutional neural network architecture designed for mobile and resource-constrained environments. It is based on an inverted residual structure where the shortcut connections are between the thin bottleneck layers, and it uses lightweight depthwise convolutions together with linear bottlenecks to filter features.
+
+[Paper](https://arxiv.org/abs/1801.04381) Sandler, Mark, Andrew Howard, Menglong Zhu, Andrey Zhmoginov, and Liang-Chieh Chen. "MobileNetV2: Inverted Residuals and Linear Bottlenecks." In Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, pp. 4510-4520. 2018.
+
+# [Model architecture](#contents)
+
+The overall network architecture of MobileNetV2 is shown below:
+
+[Link](https://arxiv.org/abs/1801.04381)
+
+# [Dataset](#contents)
+
+Dataset used: [ImageNet](http://www.image-net.org/)
+
+- Dataset size: ~125G, about 1.28 million colorful images in 1000 classes
+    - Train: 120G, about 1.28 million images
+    - Test: 5G, 50000 images
+- Data format: RGB images.
+    - Note: Data will be processed in src/dataset.py
+
+# [Features](#contents)
+
+## [Mixed Precision(Ascend)](#contents)
+
+The [mixed precision](https://www.mindspore.cn/tutorial/zh-CN/master/advanced_use/mixed_precision.html) training method accelerates the deep learning neural network training process by using both the single-precision and half-precision data formats, and maintains the network precision achieved by the single-precision training at the same time. Mixed precision training can accelerate the computation process, reduce memory usage, and enable a larger model or batch size to be trained on specific hardware.
+For FP16 operators, if the input data type is FP32, the backend of MindSpore will automatically handle it with reduced precision. Users could check the reduced-precision operators by enabling INFO log and then searching 'reduce precision'.
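+
+As an illustrative sketch of enabling mixed precision at the `Model` level (the `amp_level="O2"` wiring shown here is an assumption for demonstration, not an excerpt from train.py):
+
+```python
+from mindspore import nn
+from mindspore.train.model import Model
+from src.mobilenetV2 import mobilenet_v2
+
+net = mobilenet_v2(device_target="Ascend", num_classes=1000)
+opt = nn.Momentum(net.trainable_params(), learning_rate=0.4, momentum=0.9)
+loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
+
+# "O2" runs most of the network in float16 while keeping BatchNorm in float32
+model = Model(net, loss_fn=loss, optimizer=opt, metrics={'acc'}, amp_level="O2")
+```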
+
+# [Environment Requirements](#contents)
+
+- Hardware(Ascend/GPU)
+    - Prepare hardware environment with Ascend or GPU processor. If you want to try Ascend, please send the [application form](https://obs-9be7.obs.cn-east-2.myhuaweicloud.com/file/other/Ascend%20Model%20Zoo%E4%BD%93%E9%AA%8C%E8%B5%84%E6%BA%90%E7%94%B3%E8%AF%B7%E8%A1%A8.docx) to ascend@huawei.com. Once approved, you can get the resources.
+- Framework
+    - [MindSpore](http://10.90.67.50/mindspore/archive/20200506/OpenSource/me_vm_x86/)
+- For more information, please check the resources below:
+    - [MindSpore tutorials](https://www.mindspore.cn/tutorial/zh-CN/master/index.html)
+    - [MindSpore API](https://www.mindspore.cn/api/zh-CN/master/index.html)
+
+# [Script description](#contents)
+
+## [Script and sample code](#contents)
+
+```
+├── MobileNetV2
+    ├── Readme.md           # descriptions about MobileNetV2
+    ├── scripts
+    │   ├── run_train.sh    # shell script for train
+    │   ├── run_eval.sh     # shell script for evaluation
+    ├── src
+    │   ├── config.py       # parameter configuration
+    │   ├── dataset.py      # creating dataset
+    │   ├── launch.py       # start python script
+    │   ├── lr_generator.py # learning rate config
+    │   ├── mobilenetV2.py  # MobileNetV2 architecture
+    ├── train.py            # training script
+    ├── eval.py             # evaluation script
+```
+
+## [Training process](#contents)
+
+### Usage
+
+You can start training using python or shell scripts. The usage of the shell scripts is as follows:
+
+- Ascend: sh run_train.sh Ascend [DEVICE_NUM] [VISIABLE_DEVICES(0,1,2,3,4,5,6,7)] [RANK_TABLE_FILE] [DATASET_PATH] [CKPT_PATH]
+- GPU: sh run_train.sh GPU [DEVICE_NUM] [VISIABLE_DEVICES(0,1,2,3,4,5,6,7)] [DATASET_PATH]
+
+### Launch
+
+```
+# training example
+  python:
+      Ascend: python train.py --dataset_path ~/imagenet/train/ --device_target Ascend
+      GPU: python train.py --dataset_path ~/imagenet/train/ --device_target GPU
+
+  shell:
+      Ascend: sh run_train.sh Ascend 8 0,1,2,3,4,5,6,7 hccl_config.json ~/imagenet/train/ mobilenet_199.ckpt
+      GPU: sh run_train.sh GPU 8 0,1,2,3,4,5,6,7 ~/imagenet/train/
+```
+
+### Result
+
+Training results will be stored in the example path. Checkpoints will be stored at `./checkpoint` by default, and the training log will be redirected to `./train/train.log` like the following:
+
+```
+epoch: [  0/200], step:[  624/  625], loss:[5.258/5.258], time:[140412.236], lr:[0.100]
+epoch time: 140522.500, per step time: 224.836, avg loss: 5.258
+epoch: [  1/200], step:[  624/  625], loss:[3.917/3.917], time:[138221.250], lr:[0.200]
+epoch time: 138331.250, per step time: 221.330, avg loss: 3.917
+```
+
+## [Eval process](#contents)
+
+### Usage
+
+You can start evaluation using python or shell scripts. The usage of the shell scripts is as follows:
+
+- Ascend: sh run_infer.sh Ascend [DATASET_PATH] [CHECKPOINT_PATH]
+- GPU: sh run_infer.sh GPU [DATASET_PATH] [CHECKPOINT_PATH]
+
+### Launch
+
+```
+# infer example
+  python:
+      Ascend: python eval.py --dataset_path ~/imagenet/val/ --checkpoint_path mobilenet_199.ckpt --device_target Ascend
+      GPU: python eval.py --dataset_path ~/imagenet/val/ --checkpoint_path mobilenet_199.ckpt --device_target GPU
+
+  shell:
+      Ascend: sh run_infer.sh Ascend ~/imagenet/val/ ~/train/mobilenet-200_625.ckpt
+      GPU: sh run_infer.sh GPU ~/imagenet/val/ ~/train/mobilenet-200_625.ckpt
+```
+
+> checkpoint can be produced in training process.
+
+### Result
+
+Inference results will be stored in the example path; you can find results like the following in `val.log`.
+
+```
+result: {'acc': 0.71976314102564111} ckpt=/path/to/checkpoint/mobilenet-200_625.ckpt
+```
+
+# [Model description](#contents)
+
+## [Performance](#contents)
+
+### Training Performance
+
+| Parameters                 | MobileNetV2                                                 |                           |
+| -------------------------- | ----------------------------------------------------------- | ------------------------- |
+| Model Version              |                                                              | large                     |
+| Resource                   | Ascend 910, cpu:2.60GHz 56cores, memory:314G                 | NV SMX2 V100-32G          |
+| uploaded Date              | 05/06/2020                                                   | 05/06/2020                |
+| MindSpore Version          | 0.3.0                                                        | 0.3.0                     |
+| Dataset                    | ImageNet                                                     | ImageNet                  |
+| Training Parameters        | src/config.py                                                | src/config.py             |
+| Optimizer                  | Momentum                                                     | Momentum                  |
+| Loss Function              | SoftmaxCrossEntropy                                          | SoftmaxCrossEntropy       |
+| outputs                    |                                                              |                           |
+| Loss                       |                                                              | 1.913                     |
+| Accuracy                   |                                                              | ACC1[77.09%] ACC5[92.57%] |
+| Total time                 |                                                              |                           |
+| Params (M)                 |                                                              |                           |
+| Checkpoint for Fine tuning |                                                              |                           |
+| Model for inference        |                                                              |                           |
+
+### Inference Performance
+
+| Parameters          |                |                           |                |
+| ------------------- | -------------- | ------------------------- | -------------- |
+| Model Version       | V1             |                           |                |
+| Resource            | Ascend 910     | NV SMX2 V100-32G          | Ascend 310     |
+| uploaded Date       | 05/06/2020     | 05/22/2020                |                |
+| MindSpore Version   | 0.2.0          | 0.2.0                     | 0.2.0          |
+| Dataset             | ImageNet, 1.2M | ImageNet, 1.2M            | ImageNet, 1.2M |
+| batch_size          |                | 130(8P)                   |                |
+| outputs             |                |                           |                |
+| Accuracy            |                | ACC1[72.07%] ACC5[90.90%] |                |
+| Speed               |                |                           |                |
+| Total time          |                |                           |                |
+| Model for inference |                |                           |                |
+
+# [Description of Random Situation](#contents)
+
+In dataset.py, we set the seed inside the "create_dataset" function. We also use random seed in train.py.
+
+# [ModelZoo Homepage](#contents)
+
+Please check the official [homepage](https://gitee.com/mindspore/mindspore/tree/master/model_zoo).
\ No newline at end of file
diff --git a/tutorials/tutorial_code/mobilenetv2_finetune/eval.py b/tutorials/tutorial_code/mobilenetv2_finetune/eval.py
new file mode 100644
index 0000000000000000000000000000000000000000..9d65ec560c768857a2d78539e923985e93305dda
--- /dev/null
+++ b/tutorials/tutorial_code/mobilenetv2_finetune/eval.py
@@ -0,0 +1,62 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""
+eval.
+"""
+import os
+from mindspore import nn
+from mindspore.train.model import Model
+from mindspore.train.serialization import load_checkpoint, load_param_into_net
+from mindspore.common import dtype as mstype
+from src.dataset import create_dataset
+from src.config import set_config
+from src.args import eval_parse_args
+from src.utils import set_context, switch_precision
+from src.mobilenetV2 import mobilenet_v2
+
+
+if __name__ == '__main__':
+
+    # os.system("export ...") would only affect a subshell; set the variables for this process
+    os.environ["DEVICE_ID"] = "0"
+    os.environ["RANK_ID"] = "0"
+    os.environ["RANK_SIZE"] = "1"
+    if os.path.exists("./eval"):
+        os.system("rm -rf ./eval;mkdir eval")
+
+    args_opt = eval_parse_args()
+    config = set_config(args_opt)
+    set_context(args_opt, config)
+    # MobileNetV2 takes the target device as `device_target`
+    net = mobilenet_v2(num_classes=args_opt.num_classes, device_target=args_opt.platform)
+
+    loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True, reduction='mean')
+
+    if args_opt.platform == "Ascend":
+        switch_precision(net, mstype.float16)
+
+    dataset = create_dataset(args=args_opt, do_train=False, repeat_num=1)
+
+    assert args_opt.checkpoint_path, "--checkpoint_path must be provided for evaluation"
+    param_dict = load_checkpoint(args_opt.checkpoint_path)
+    load_param_into_net(net, param_dict)
+
+    net.set_train(False)
+
+    model = Model(net, loss_fn=loss, metrics={'acc'})
+    res = model.eval(dataset)
+    print("result:", res, "ckpt=", args_opt.checkpoint_path)
diff --git a/tutorials/tutorial_code/mobilenetv2_finetune/launch.py b/tutorials/tutorial_code/mobilenetv2_finetune/launch.py
new file mode 100644
index 0000000000000000000000000000000000000000..b786d09a8074213fba420e0bd2b53aa5aa66addf
--- /dev/null
+++ b/tutorials/tutorial_code/mobilenetv2_finetune/launch.py
@@ -0,0 +1,77 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""launch train script"""
+import os
+import sys
+import subprocess
+import shutil
+from src.args import launch_parse_args
+
+def main():
+    print("start", __file__)
+    args = launch_parse_args()
+    print(args)
+    os.system("rm -rf ./train;mkdir ./train")
+    visible_devices = args.visible_devices.split(',')
+    device_num = len(visible_devices)
+
+    if args.platform == "Ascend":
+        assert os.path.isfile(args.hccl_config_path)
+        # os.system("export ...") would only affect a subshell; set the variables for this process
+        os.environ["MINDSPORE_HCCL_CONFIG_PATH"] = args.hccl_config_path
+        os.environ["RANK_TABLE_FILE"] = args.hccl_config_path
+
+    elif args.platform == "GPU":
+        # restrict the GPUs visible to the spawned processes; distributed GPU
+        # training itself is launched through mpirun (see run_train.sh in the README)
+        os.environ["CUDA_VISIBLE_DEVICES"] = args.visible_devices
+
+    assert os.path.isfile(args.training_script)
+    assert len(visible_devices) >= args.nproc_per_node
+    print('visible_devices:{}'.format(visible_devices))
+
+    # spawn the processes
+    processes = []
+    cmds = []
+    log_files = []
+    env = os.environ.copy()
+    env['RANK_SIZE'] = str(args.nproc_per_node)
+    cur_path = os.getcwd()
+    training_script = os.path.join(cur_path, args.training_script)
+    for rank_id in range(0, args.nproc_per_node):
+        os.chdir(cur_path)
+        device_id = visible_devices[rank_id]
+        device_dir = os.path.join(cur_path, 'train/device{}'.format(rank_id))
+        env['RANK_ID'] = str(rank_id)
+        env['DEVICE_ID'] = str(device_id)
+        if os.path.exists(device_dir):
+            shutil.rmtree(device_dir)
+        os.mkdir(device_dir)
+        os.chdir(device_dir)
+        cmd = [sys.executable, '-u']
+        cmd.append(training_script)
+        cmd.extend(args.training_script_args)
+        log_file = open('{dir}/log{id}.log'.format(dir=device_dir, id=rank_id), 'w')
+        process = subprocess.Popen(cmd, stdout=log_file, stderr=log_file, env=env)
+        processes.append(process)
+        cmds.append(cmd)
+        log_files.append(log_file)
+
+    for process, cmd, log_file in zip(processes, cmds, log_files):
+        process.wait()
+        if process.returncode != 0:
+            raise subprocess.CalledProcessError(returncode=process.returncode, cmd=cmd)
+        log_file.close()
+
+
+if __name__ == "__main__":
+    main()
diff --git a/tutorials/tutorial_code/mobilenetv2_finetune/src/args.py b/tutorials/tutorial_code/mobilenetv2_finetune/src/args.py
new file mode 100644
index 0000000000000000000000000000000000000000..ed31c8a855c8bec16d687ea3e22becc262622604
--- /dev/null
+++ b/tutorials/tutorial_code/mobilenetv2_finetune/src/args.py
@@ -0,0 +1,73 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""command line argument parsing for launch, train and eval"""
+
+import argparse
+
+def launch_parse_args():
+    launch_parser = argparse.ArgumentParser(description="mindspore distributed training launch helper utility that will spawn up multiple distributed processes")
+    launch_parser.add_argument('--platform', type=str, default="Ascend", choices=("Ascend", "GPU", "CPU"), help='run platform, only support GPU, CPU and Ascend')
+    launch_parser.add_argument("--nproc_per_node", type=int, default=1, choices=(1, 2, 3, 4, 5, 6, 7, 8), help="The number of processes to launch on each node; for distributed training this is recommended to be the number of devices in your system so that each process can be bound to a single device.")
+    launch_parser.add_argument("--visible_devices", type=str, default="0,1,2,3,4,5,6,7", help="will use the visible devices sequentially")
+    launch_parser.add_argument("--hccl_config_path", type=str, default=None, help="The hccl config file for distributed training; recommended to be set when using the Ascend platform")
+    launch_parser.add_argument("--training_script", type=str, default="./train.py", help="The full path to the training program/script to be launched in parallel, followed by all the arguments for the training script")
+
+    launch_args, unknown = launch_parser.parse_known_args()
+    launch_args.training_script_args = unknown
+    launch_args.training_script_args += ["--platform", launch_args.platform]
+    return launch_args
+
+def train_parse_args():
+    train_parser = argparse.ArgumentParser(description='Image classification')
+    train_parser.add_argument('--dataset_path', type=str, required=True, help='Dataset path')
+    train_parser.add_argument('--platform', type=str, default="Ascend", choices=("Ascend", "GPU", "CPU"), help='run platform, only support GPU, CPU and Ascend')
+    train_parser.add_argument('--checkpoint_path', type=str, default=None, help='the path of checkpoint file to eval')
+    train_parser.add_argument('--num_classes', type=int, default=1000, help='the classes number of image classification')
+    train_parser.add_argument('--image_height', type=int, default=224, help='the height of image')
+    train_parser.add_argument('--image_width', type=int, default=224, help='the width of image')
+    train_parser.add_argument('--batch_size', type=int, default=128, help='the image num of every step in training process')
+    train_parser.add_argument('--epoch_size', type=int, default=200, help='the count of whole dataset to be trained')
+    train_parser.add_argument('--momentum', type=float, default=0.9, help='optimizer args')
+    train_parser.add_argument('--weight_decay', type=float, default=4e-5, help='trainable params decay weight')
+    train_parser.add_argument('--label_smooth', type=float, default=0.1, help='label smooth used in loss computation')
+    train_parser.add_argument('--loss_scale', type=int, default=1024, help='loss scale')
+    train_parser.add_argument('--save_checkpoint', type=bool, default=True, help='if to save checkpoint file or not')
+    train_parser.add_argument('--save_checkpoint_epochs', type=int, default=1, help='how many epochs to save ckpoint file')
+    train_parser.add_argument('--keep_checkpoint_max', type=int, default=200, help='the max of checkpoint count to save')
+    train_parser.add_argument('--save_checkpoint_path', type=str, default="./checkpoint", help='file path to save checkpoint file')
+
+    # parse once to learn the platform, then register the platform-specific arguments below
+    train_args, _ = train_parser.parse_known_args()
+
+    if train_args.platform == "Ascend":
+        train_parser.add_argument('--lr', type=float, default=0.4, help='learning rate')
+        train_parser.add_argument('--warmup_epochs', type=float, default=4, help='warmup_epochs after load pretrained model')
+
+    elif train_args.platform in ("CPU", "GPU"):
+        train_parser.add_argument('--lr', type=float, default=0.8, help='learning rate')
+        train_parser.add_argument('--warmup_epochs', type=float, default=0, help='warmup_epochs after load pretrained model')
+
+    train_args = train_parser.parse_args()
+    return train_args
+
+def eval_parse_args():
+    eval_parser = argparse.ArgumentParser(description='Image classification')
+    eval_parser.add_argument('--platform', type=str, default="Ascend", choices=("Ascend", "GPU", "CPU"), help='run platform, only support GPU, CPU and Ascend')
+    eval_parser.add_argument('--dataset_path', type=str, required=True, help='Dataset path')
+    eval_parser.add_argument('--checkpoint_path', type=str, default=None, help='the path of checkpoint file to eval')
+    eval_parser.add_argument('--num_classes', type=int, default=1000, help='the classes number of image classification')
+    eval_parser.add_argument('--image_height', type=int, default=224, help='the height of image')
+    eval_parser.add_argument('--image_width', type=int, default=224, help='the width of image')
+    eval_parser.add_argument('--batch_size', type=int, default=150, help='the image num of every step in evaluation process')
+    eval_args = eval_parser.parse_args()
+    return eval_args
\ No newline at end of file
diff --git a/tutorials/tutorial_code/mobilenetv2_finetune/src/config.py b/tutorials/tutorial_code/mobilenetv2_finetune/src/config.py
new file mode 100644
index 0000000000000000000000000000000000000000..55d735a69f55f13f549cacf301f12aec23c2892b
--- /dev/null
+++ b/tutorials/tutorial_code/mobilenetv2_finetune/src/config.py
@@ -0,0 +1,48 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""
+network config setting, will be used in train.py and eval.py
+"""
+import os
+
+from easydict import EasyDict as ed
+
+def set_config(args):
+    if args.platform == "GPU":
+        config = ed({
+            "ccl": "nccl",
+            "platform": args.platform,
+        })
+
+    elif args.platform == "CPU":
+        config = ed({
+            "platform": args.platform,
+        })
+
+    elif args.platform == "Ascend":
+        config = ed({
+            "ccl": "hccl",
+            "platform": args.platform,
+            "device_id": int(os.getenv('DEVICE_ID')),
+            "rank_id": int(os.getenv('RANK_ID')),
+            "rank_size": int(os.getenv('RANK_SIZE')),
+            "run_distribute": int(os.getenv('RANK_SIZE')) > 1
+ }) + else: + raise ValueError("Unsupport platform.") + + return config + diff --git a/tutorials/tutorial_code/mobilenetv2_finetune/src/dataset.py b/tutorials/tutorial_code/mobilenetv2_finetune/src/dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..47e1f3fd0efd871220771f85415fa089c78710a0 --- /dev/null +++ b/tutorials/tutorial_code/mobilenetv2_finetune/src/dataset.py @@ -0,0 +1,92 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +""" +create train or eval dataset. +""" +import os +import mindspore.common.dtype as mstype +import mindspore.dataset.engine as de +import mindspore.dataset.transforms.vision.c_transforms as C +import mindspore.dataset.transforms.c_transforms as C2 + +def create_dataset(args, do_train, repeat_num=1): + """ + create a train or eval dataset + + Args: + dataset_path(string): the path of dataset. + do_train(bool): whether dataset is used for train or eval. + repeat_num(int): the repeat times of dataset. Default: 1. + batch_size(int): the batch size of dataset. Default: 32. + + Returns: + dataset + """ + if args.platform == "Ascend": + rank_size = int(os.getenv("RANK_SIZE")) + rank_id = int(os.getenv("RANK_ID")) + if rank_size == 1: + ds = de.ImageFolderDatasetV2(args.dataset_path, num_parallel_workers=8, shuffle=True) + else: + ds = de.ImageFolderDatasetV2(args.dataset_path, num_parallel_workers=8, shuffle=True, + num_shards=rank_size, shard_id=rank_id) + elif args.platform == "GPU": + if do_train: + from mindspore.communication.management import get_rank, get_group_size + ds = de.ImageFolderDatasetV2(args.dataset_path, num_parallel_workers=8, shuffle=True, + num_shards=get_group_size(), shard_id=get_rank()) + else: + ds = de.ImageFolderDatasetV2(args.dataset_path, num_parallel_workers=8, shuffle=True) + #TODO add the part of CPU dataset + elif args.platform == "CPU": + ds = de.ImageFolderDatasetV2(args.dataset_path, num_parallel_workers=8, shuffle=True) + else: + raise ValueError("Unsupport platform.") + + resize_height = args.image_height + resize_width = args.image_width + buffer_size = 1000 + + # define map operations + decode_op = C.Decode() + resize_crop_op = C.RandomCropDecodeResize(resize_height, scale=(0.08, 1.0), ratio=(0.75, 1.333)) + horizontal_flip_op = C.RandomHorizontalFlip(prob=0.5) + + resize_op = C.Resize((256, 256)) + center_crop = C.CenterCrop(resize_width) + rescale_op = C.RandomColorAdjust(brightness=0.4, contrast=0.4, saturation=0.4) + normalize_op = C.Normalize(mean=[0.485*255, 0.456*255, 0.406*255], std=[0.229*255, 0.224*255, 0.225*255]) + change_swap_op = C.HWC2CHW() + + if do_train: + trans = [resize_crop_op, horizontal_flip_op, rescale_op, normalize_op, change_swap_op] + else: + trans = [decode_op, resize_op, center_crop, normalize_op, change_swap_op] + + type_cast_op = C2.TypeCast(mstype.int32) + + ds = ds.map(input_columns="image", operations=trans, num_parallel_workers=8) + ds = 
ds.map(input_columns="label", operations=type_cast_op, num_parallel_workers=8) + + # apply shuffle operations + ds = ds.shuffle(buffer_size=buffer_size) + + # apply batch operations + ds = ds.batch(args.batch_size, drop_remainder=True) + + # apply dataset repeat operation + ds = ds.repeat(repeat_num) + + return ds \ No newline at end of file diff --git a/tutorials/tutorial_code/mobilenetv2_finetune/src/launch.py b/tutorials/tutorial_code/mobilenetv2_finetune/src/launch.py new file mode 100644 index 0000000000000000000000000000000000000000..f5c97b0bd7060c4c6c11219ef8f9dfafb85f1f58 --- /dev/null +++ b/tutorials/tutorial_code/mobilenetv2_finetune/src/launch.py @@ -0,0 +1,97 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""launch train script""" +import os +import sys +import subprocess +import shutil +from argparse import ArgumentParser + +def parse_args(): + """ + parse args . + + Args: + + Returns: + args. + + Examples: + >>> parse_args() + """ + parser = ArgumentParser(description="mindspore distributed training launch " + "helper utilty that will spawn up " + "multiple distributed processes") + parser.add_argument("--nproc_per_node", type=int, default=1, + help="The number of processes to launch on each node, " + "for D training, this is recommended to be set " + "to the number of D in your system so that " + "each process can be bound to a single D.") + parser.add_argument("--visible_devices", type=str, default="0,1,2,3,4,5,6,7", + help="will use the visible devices sequentially") + parser.add_argument("--training_script", type=str, + help="The full path to the single D training " + "program/script to be launched in parallel, " + "followed by all the arguments for the " + "training script") + # rest from the training program + args, unknown = parser.parse_known_args() + args.training_script_args = unknown + return args + + +def main(): + print("start", __file__) + args = parse_args() + print(args) + visible_devices = args.visible_devices.split(',') + assert os.path.isfile(args.training_script) + assert len(visible_devices) >= args.nproc_per_node + print('visible_devices:{}'.format(visible_devices)) + + # spawn the processes + processes = [] + cmds = [] + log_files = [] + env = os.environ.copy() + env['RANK_SIZE'] = str(args.nproc_per_node) + cur_path = os.getcwd() + for rank_id in range(0, args.nproc_per_node): + os.chdir(cur_path) + device_id = visible_devices[rank_id] + device_dir = os.path.join(cur_path, 'device{}'.format(rank_id)) + env['RANK_ID'] = str(rank_id) + env['DEVICE_ID'] = str(device_id) + if os.path.exists(device_dir): + shutil.rmtree(device_dir) + os.mkdir(device_dir) + os.chdir(device_dir) + cmd = [sys.executable, '-u'] + cmd.append(args.training_script) + cmd.extend(args.training_script_args) + log_file = open('{dir}/log{id}.log'.format(dir=device_dir, id=rank_id), 'w') + process = subprocess.Popen(cmd, stdout=log_file, stderr=log_file, 
env=env)
+        processes.append(process)
+        cmds.append(cmd)
+        log_files.append(log_file)
+    for process, cmd, log_file in zip(processes, cmds, log_files):
+        process.wait()
+        if process.returncode != 0:
+            raise subprocess.CalledProcessError(returncode=process.returncode, cmd=cmd)
+        log_file.close()
+
+
+if __name__ == "__main__":
+    main()
diff --git a/tutorials/tutorial_code/mobilenetv2_finetune/src/lr_generator.py b/tutorials/tutorial_code/mobilenetv2_finetune/src/lr_generator.py
new file mode 100644
index 0000000000000000000000000000000000000000..68bbfe315847d4d291d3d5f30df1f8f94de1dc27
--- /dev/null
+++ b/tutorials/tutorial_code/mobilenetv2_finetune/src/lr_generator.py
@@ -0,0 +1,54 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""learning rate generator"""
+import math
+import numpy as np
+
+
+def get_lr(global_step, lr_init, lr_end, lr_max, warmup_epochs, total_epochs, steps_per_epoch):
+    """
+    generate learning rate array
+
+    Args:
+        global_step(int): total steps of the training
+        lr_init(float): init learning rate
+        lr_end(float): end learning rate
+        lr_max(float): max learning rate
+        warmup_epochs(int): number of warmup epochs
+        total_epochs(int): total epoch of training
+        steps_per_epoch(int): steps of one epoch
+
+    Returns:
+        np.array, learning rate array
+    """
+    lr_each_step = []
+    total_steps = steps_per_epoch * total_epochs
+    warmup_steps = steps_per_epoch * warmup_epochs
+    for i in range(total_steps):
+        if i < warmup_steps:
+            # linear warmup from lr_init up to lr_max
+            lr = lr_init + (lr_max - lr_init) * i / warmup_steps
+        else:
+            # cosine decay from lr_max down to lr_end
+            lr = lr_end + \
+                 (lr_max - lr_end) * \
+                 (1. + math.cos(math.pi * (i - warmup_steps) / (total_steps - warmup_steps))) / 2.
+        if lr < 0.0:
+            lr = 0.0
+        lr_each_step.append(lr)
+
+    current_step = global_step
+    lr_each_step = np.array(lr_each_step).astype(np.float32)
+    learning_rate = lr_each_step[current_step:]
+
+    return learning_rate
diff --git a/tutorials/tutorial_code/mobilenetv2_finetune/src/mobilenetV2.py b/tutorials/tutorial_code/mobilenetv2_finetune/src/mobilenetV2.py
new file mode 100644
index 0000000000000000000000000000000000000000..6d0b6f38e0f803b1bdbd01c3dc1782a336a2811c
--- /dev/null
+++ b/tutorials/tutorial_code/mobilenetv2_finetune/src/mobilenetV2.py
@@ -0,0 +1,292 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================ +"""MobileNetV2 model define""" +import numpy as np +import mindspore.nn as nn +from mindspore.ops import operations as P +from mindspore.ops.operations import TensorAdd +from mindspore import Parameter, Tensor +from mindspore.common.initializer import initializer + +__all__ = ['mobilenet_v2'] + + +def _make_divisible(v, divisor, min_value=None): + if min_value is None: + min_value = divisor + new_v = max(min_value, int(v + divisor / 2) // divisor * divisor) + # Make sure that round down does not go down by more than 10%. + if new_v < 0.9 * v: + new_v += divisor + return new_v + + +class GlobalAvgPooling(nn.Cell): + """ + Global avg pooling definition. + + Args: + + Returns: + Tensor, output tensor. + + Examples: + >>> GlobalAvgPooling() + """ + + def __init__(self): + super(GlobalAvgPooling, self).__init__() + self.mean = P.ReduceMean(keep_dims=False) + + def construct(self, x): + x = self.mean(x, (2, 3)) + return x + + +class DepthwiseConv(nn.Cell): + """ + Depthwise Convolution warpper definition. + + Args: + in_planes (int): Input channel. + kernel_size (int): Input kernel size. + stride (int): Stride size. + pad_mode (str): pad mode in (pad, same, valid) + channel_multiplier (int): Output channel multiplier + has_bias (bool): has bias or not + + Returns: + Tensor, output tensor. + + Examples: + >>> DepthwiseConv(16, 3, 1, 'pad', 1, channel_multiplier=1) + """ + + def __init__(self, in_planes, kernel_size, stride, pad_mode, pad, channel_multiplier=1, has_bias=False): + super(DepthwiseConv, self).__init__() + self.has_bias = has_bias + self.in_channels = in_planes + self.channel_multiplier = channel_multiplier + self.out_channels = in_planes * channel_multiplier + self.kernel_size = (kernel_size, kernel_size) + self.depthwise_conv = P.DepthwiseConv2dNative(channel_multiplier=channel_multiplier, + kernel_size=self.kernel_size, + stride=stride, pad_mode=pad_mode, pad=pad) + self.bias_add = P.BiasAdd() + weight_shape = [channel_multiplier, in_planes, *self.kernel_size] + self.weight = Parameter(initializer('ones', weight_shape), name='weight') + + if has_bias: + bias_shape = [channel_multiplier * in_planes] + self.bias = Parameter(initializer('zeros', bias_shape), name='bias') + else: + self.bias = None + + def construct(self, x): + output = self.depthwise_conv(x, self.weight) + if self.has_bias: + output = self.bias_add(output, self.bias) + return output + + +class ConvBNReLU(nn.Cell): + """ + Convolution/Depthwise fused with Batchnorm and ReLU block definition. + + Args: + in_planes (int): Input channel. + out_planes (int): Output channel. + kernel_size (int): Input kernel size. + stride (int): Stride size for the first convolutional layer. Default: 1. + groups (int): channel group. Convolution is 1 while Depthiwse is input channel. Default: 1. + + Returns: + Tensor, output tensor. 
+ + Examples: + >>> ConvBNReLU(16, 256, kernel_size=1, stride=1, groups=1) + """ + + def __init__(self, device_target, in_planes, out_planes, kernel_size=3, stride=1, groups=1): + super(ConvBNReLU, self).__init__() + padding = (kernel_size - 1) // 2 + if groups == 1: + conv = nn.Conv2d(in_planes, out_planes, kernel_size, stride, pad_mode='pad', padding=padding) + else: + if device_target == "Ascend": + conv = DepthwiseConv(in_planes, kernel_size, stride, pad_mode='pad', pad=padding) + elif device_target == "GPU": + conv = nn.Conv2d(in_planes, out_planes, kernel_size, stride, + group=in_planes, pad_mode='pad', padding=padding) + + layers = [conv, nn.BatchNorm2d(out_planes), nn.ReLU6()] + self.features = nn.SequentialCell(layers) + + def construct(self, x): + output = self.features(x) + return output + + +class InvertedResidual(nn.Cell): + """ + Mobilenetv2 residual block definition. + + Args: + inp (int): Input channel. + oup (int): Output channel. + stride (int): Stride size for the first convolutional layer. Default: 1. + expand_ratio (int): expand ration of input channel + + Returns: + Tensor, output tensor. + + Examples: + >>> ResidualBlock(3, 256, 1, 1) + """ + + def __init__(self, device_target, inp, oup, stride, expand_ratio): + super(InvertedResidual, self).__init__() + assert stride in [1, 2] + + hidden_dim = int(round(inp * expand_ratio)) + self.use_res_connect = stride == 1 and inp == oup + + layers = [] + if expand_ratio != 1: + layers.append(ConvBNReLU(device_target, inp, hidden_dim, kernel_size=1)) + layers.extend([ + # dw + ConvBNReLU(device_target, hidden_dim, hidden_dim, + stride=stride, groups=hidden_dim), + # pw-linear + nn.Conv2d(hidden_dim, oup, kernel_size=1, + stride=1, has_bias=False), + nn.BatchNorm2d(oup), + ]) + self.conv = nn.SequentialCell(layers) + self.add = TensorAdd() + self.cast = P.Cast() + + def construct(self, x): + identity = x + x = self.conv(x) + if self.use_res_connect: + return self.add(identity, x) + return x + + +class MobileNetV2(nn.Cell): + """ + MobileNetV2 architecture. + + Args: + class_num (Cell): number of classes. + width_mult (int): Channels multiplier for round to 8/16 and others. Default is 1. + has_dropout (bool): Is dropout used. Default is false + inverted_residual_setting (list): Inverted residual settings. Default is None + round_nearest (list): Channel round to . Default is 8 + Returns: + Tensor, output tensor. 
+ + Examples: + >>> MobileNetV2(num_classes=1000) + """ + + def __init__(self, device_target, num_classes=1000, width_mult=1., + has_dropout=False, inverted_residual_setting=None, round_nearest=8): + super(MobileNetV2, self).__init__() + block = InvertedResidual + input_channel = 32 + last_channel = 1280 + # setting of inverted residual blocks + self.cfgs = inverted_residual_setting + if inverted_residual_setting is None: + self.cfgs = [ + # t, c, n, s + [1, 16, 1, 1], + [6, 24, 2, 2], + [6, 32, 3, 2], + [6, 64, 4, 2], + [6, 96, 3, 1], + [6, 160, 3, 2], + [6, 320, 1, 1], + ] + + # building first layer + input_channel = _make_divisible(input_channel * width_mult, round_nearest) + self.out_channels = _make_divisible(last_channel * max(1.0, width_mult), round_nearest) + features = [ConvBNReLU(device_target, 3, input_channel, stride=2)] + # building inverted residual blocks + for t, c, n, s in self.cfgs: + output_channel = _make_divisible(c * width_mult, round_nearest) + for i in range(n): + stride = s if i == 0 else 1 + features.append(block(device_target, input_channel, output_channel, stride, expand_ratio=t)) + input_channel = output_channel + # building last several layers + features.append(ConvBNReLU(device_target, input_channel, self.out_channels, kernel_size=1)) + # make it nn.CellList + self.features = nn.SequentialCell(features) + # mobilenet head + head = ([GlobalAvgPooling(), nn.Dense(self.out_channels, num_classes, has_bias=True)] if not has_dropout else + [GlobalAvgPooling(), nn.Dropout(0.2), nn.Dense(self.out_channels, num_classes, has_bias=True)]) + self.head = nn.SequentialCell(head) + + self._initialize_weights() + + def construct(self, x): + x = self.features(x) + x = self.head(x) + return x + + def _initialize_weights(self): + """ + Initialize weights. + + Args: + + Returns: + None. + + Examples: + >>> _initialize_weights() + """ + self.init_parameters_data() + for _, m in self.cells_and_names(): + if isinstance(m, (nn.Conv2d, DepthwiseConv)): + n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels + m.weight.set_parameter_data(Tensor(np.random.normal(0, np.sqrt(2. / n), + m.weight.data.shape).astype("float32"))) + if m.bias is not None: + m.bias.set_parameter_data( + Tensor(np.zeros(m.bias.data.shape, dtype="float32"))) + elif isinstance(m, nn.BatchNorm2d): + m.gamma.set_parameter_data( + Tensor(np.ones(m.gamma.data.shape, dtype="float32"))) + m.beta.set_parameter_data( + Tensor(np.zeros(m.beta.data.shape, dtype="float32"))) + elif isinstance(m, nn.Dense): + m.weight.set_parameter_data(Tensor(np.random.normal( + 0, 0.01, m.weight.data.shape).astype("float32"))) + if m.bias is not None: + m.bias.set_parameter_data( + Tensor(np.zeros(m.bias.data.shape, dtype="float32"))) + + +def mobilenet_v2(**kwargs): + """ + Constructs a MobileNet V2 model + """ + return MobileNetV2(**kwargs) diff --git a/tutorials/tutorial_code/mobilenetv2_finetune/src/mobilenetV2_fusion.py b/tutorials/tutorial_code/mobilenetv2_finetune/src/mobilenetV2_fusion.py new file mode 100644 index 0000000000000000000000000000000000000000..715231d8fcadfe0342e0b3e0b29aedc09306d9a8 --- /dev/null +++ b/tutorials/tutorial_code/mobilenetv2_finetune/src/mobilenetV2_fusion.py @@ -0,0 +1,239 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""MobileNetV2 fusion model definition."""
+
+import numpy as np
+import mindspore.nn as nn
+from mindspore.ops import operations as P
+from mindspore import Tensor
+
+__all__ = ['mobilenetV2']
+
+
+def _make_divisible(v, divisor, min_value=None):
+    if min_value is None:
+        min_value = divisor
+    new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
+    # Make sure that rounding down does not reduce the value by more than 10%.
+    if new_v < 0.9 * v:
+        new_v += divisor
+    return new_v
+
+
+class GlobalAvgPooling(nn.Cell):
+    """
+    Global average pooling definition.
+
+    Returns:
+        Tensor, output tensor.
+
+    Examples:
+        >>> GlobalAvgPooling()
+    """
+
+    def __init__(self):
+        super(GlobalAvgPooling, self).__init__()
+        self.mean = P.ReduceMean(keep_dims=False)
+
+    def construct(self, x):
+        x = self.mean(x, (2, 3))
+        return x
+
+
+class ConvBNReLU(nn.Cell):
+    """
+    Convolution/Depthwise fused with BatchNorm and ReLU block definition.
+
+    Args:
+        in_planes (int): Input channel.
+        out_planes (int): Output channel.
+        kernel_size (int): Input kernel size.
+        stride (int): Stride size for the first convolutional layer. Default: 1.
+        groups (int): Channel group. 1 for convolution, the input channel count for depthwise. Default: 1.
+
+    Returns:
+        Tensor, output tensor.
+
+    Examples:
+        >>> ConvBNReLU(16, 256, kernel_size=1, stride=1, groups=1)
+    """
+
+    def __init__(self, in_planes, out_planes, kernel_size=3, stride=1, groups=1):
+        super(ConvBNReLU, self).__init__()
+        padding = (kernel_size - 1) // 2
+        self.conv = nn.Conv2dBnAct(in_planes, out_planes, kernel_size,
+                                   stride=stride,
+                                   pad_mode='pad',
+                                   padding=padding,
+                                   group=groups,
+                                   has_bn=True,
+                                   activation='relu')
+
+    def construct(self, x):
+        x = self.conv(x)
+        return x
+
+
+class InvertedResidual(nn.Cell):
+    """
+    MobileNetV2 residual block definition.
+
+    Args:
+        inp (int): Input channel.
+        oup (int): Output channel.
+        stride (int): Stride size for the first convolutional layer. Default: 1.
+        expand_ratio (int): Expand ratio of the input channel.
+
+    Returns:
+        Tensor, output tensor.
+
+    Examples:
+        >>> InvertedResidual(3, 256, 1, 1)
+    """
+
+    def __init__(self, inp, oup, stride, expand_ratio):
+        super(InvertedResidual, self).__init__()
+        assert stride in [1, 2]
+
+        hidden_dim = int(round(inp * expand_ratio))
+        self.use_res_connect = stride == 1 and inp == oup
+
+        layers = []
+        if expand_ratio != 1:
+            layers.append(ConvBNReLU(inp, hidden_dim, kernel_size=1))
+        layers.extend([
+            ConvBNReLU(hidden_dim, hidden_dim, stride=stride, groups=hidden_dim),
+            nn.Conv2dBnAct(hidden_dim, oup, kernel_size=1, stride=1, pad_mode='pad', padding=0, group=1, has_bn=True)
+        ])
+        self.conv = nn.SequentialCell(layers)
+        self.add = P.TensorAdd()
+
+    def construct(self, x):
+        out = self.conv(x)
+        if self.use_res_connect:
+            out = self.add(out, x)
+        return out
+
+
+class mobilenetV2(nn.Cell):
+    """
+    mobilenetV2 fusion architecture.
+
+    Args:
+        num_classes (int): Number of classes. Default is 1000.
+        width_mult (float): Channel width multiplier; channels are rounded to a multiple of round_nearest. Default is 1.
+        has_dropout (bool): Whether dropout is used. Default is False.
+        inverted_residual_setting (list): Inverted residual settings. Default is None.
+        round_nearest (int): Round channel counts to a multiple of this value. Default is 8.
+
+    Returns:
+        Tensor, output tensor.
+
+    Examples:
+        >>> mobilenetV2(num_classes=1000)
+    """
+
+    def __init__(self, num_classes=1000, width_mult=1.,
+                 has_dropout=False, inverted_residual_setting=None, round_nearest=8):
+        super(mobilenetV2, self).__init__()
+        block = InvertedResidual
+        input_channel = 32
+        last_channel = 1280
+        # setting of inverted residual blocks
+        self.cfgs = inverted_residual_setting
+        if inverted_residual_setting is None:
+            self.cfgs = [
+                # t, c, n, s
+                [1, 16, 1, 1],
+                [6, 24, 2, 2],
+                [6, 32, 3, 2],
+                [6, 64, 4, 2],
+                [6, 96, 3, 1],
+                [6, 160, 3, 2],
+                [6, 320, 1, 1],
+            ]
+
+        # building first layer
+        input_channel = _make_divisible(input_channel * width_mult, round_nearest)
+        self.out_channels = _make_divisible(last_channel * max(1.0, width_mult), round_nearest)
+
+        features = [ConvBNReLU(3, input_channel, stride=2)]
+        # building inverted residual blocks
+        for t, c, n, s in self.cfgs:
+            output_channel = _make_divisible(c * width_mult, round_nearest)
+            for i in range(n):
+                stride = s if i == 0 else 1
+                features.append(block(input_channel, output_channel, stride, expand_ratio=t))
+                input_channel = output_channel
+        # building last several layers
+        features.append(ConvBNReLU(input_channel, self.out_channels, kernel_size=1))
+        # wrap the feature layers in a SequentialCell
+        self.features = nn.SequentialCell(features)
+        # mobilenet head
+        head = ([GlobalAvgPooling(),
+                 nn.DenseBnAct(self.out_channels, num_classes, has_bias=True, has_bn=False)
+                 ] if not has_dropout else
+                [GlobalAvgPooling(),
+                 nn.Dropout(0.2),
+                 nn.DenseBnAct(self.out_channels, num_classes, has_bias=True, has_bn=False)
+                 ])
+        self.head = nn.SequentialCell(head)
+
+        # init weights
+        self._initialize_weights()
+
+    def construct(self, x):
+        x = self.features(x)
+        x = self.head(x)
+        return x
+
+    def _initialize_weights(self):
+        """
+        Initialize weights.
+
+        Returns:
+            None.
+
+        Examples:
+            >>> _initialize_weights()
+        """
+        for _, m in self.cells_and_names():
+            if isinstance(m, nn.Conv2d):
+                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
+                w = Tensor(np.random.normal(0, np.sqrt(2. / n), m.weight.data.shape).astype("float32"))
+                m.weight.set_parameter_data(w)
+                if m.bias is not None:
+                    m.bias.set_parameter_data(Tensor(np.zeros(m.bias.data.shape, dtype="float32")))
+            elif isinstance(m, nn.Conv2dBnAct):
+                n = m.conv.kernel_size[0] * m.conv.kernel_size[1] * m.conv.out_channels
+                w = Tensor(np.random.normal(0, np.sqrt(2. / n), m.conv.weight.data.shape).astype("float32"))
+                m.conv.weight.set_parameter_data(w)
+                if m.conv.bias is not None:
+                    m.conv.bias.set_parameter_data(Tensor(np.zeros(m.conv.bias.data.shape, dtype="float32")))
+            elif isinstance(m, nn.BatchNorm2d):
+                m.gamma.set_parameter_data(Tensor(np.ones(m.gamma.data.shape, dtype="float32")))
+                m.beta.set_parameter_data(Tensor(np.zeros(m.beta.data.shape, dtype="float32")))
+            elif isinstance(m, nn.Dense):
+                m.weight.set_parameter_data(Tensor(np.random.normal(0, 0.01, m.weight.data.shape).astype("float32")))
+                if m.bias is not None:
+                    m.bias.set_parameter_data(Tensor(np.zeros(m.bias.data.shape, dtype="float32")))
+            elif isinstance(m, nn.DenseBnAct):
+                m.dense.weight.set_parameter_data(
+                    Tensor(np.random.normal(0, 0.01, m.dense.weight.data.shape).astype("float32")))
+                if m.dense.bias is not None:
+                    m.dense.bias.set_parameter_data(Tensor(np.zeros(m.dense.bias.data.shape, dtype="float32")))
diff --git a/tutorials/tutorial_code/mobilenetv2_finetune/src/models.py b/tutorials/tutorial_code/mobilenetv2_finetune/src/models.py
new file mode 100644
index 0000000000000000000000000000000000000000..ad0eff1686e8e81ec397574ec28194b76440f583
--- /dev/null
+++ b/tutorials/tutorial_code/mobilenetv2_finetune/src/models.py
@@ -0,0 +1,30 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+from mindspore.train.serialization import load_checkpoint, load_param_into_net
+
+
+def model_fine_tune(net, pretrained_path):
+    """Load a pretrained checkpoint into `net` and freeze the backbone.
+
+    The weights of the classification head (and their optimizer moments, if
+    present in the checkpoint) are dropped so the new head is trained from
+    scratch; all other parameters are frozen.
+    """
+    if pretrained_path is None:
+        return
+    param_dict = load_checkpoint(pretrained_path)
+    # discard the pretrained head: its shape depends on the original class count
+    param_dict.pop('head.1.weight', None)
+    param_dict.pop('head.1.bias', None)
+    param_dict.pop('moments.head.1.weight', None)
+    param_dict.pop('moments.head.1.bias', None)
+
+    load_param_into_net(net, param_dict)
+
+    # freeze every parameter except the new head
+    for param in net.get_parameters():
+        if param.name not in ("head.1.weight", "head.1.bias"):
+            param.requires_grad = False
diff --git a/tutorials/tutorial_code/mobilenetv2_finetune/src/utils.py b/tutorials/tutorial_code/mobilenetv2_finetune/src/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..ed1f1d639a467c55a23f988fa204e63e6990eb79
--- /dev/null
+++ b/tutorials/tutorial_code/mobilenetv2_finetune/src/utils.py
@@ -0,0 +1,142 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""Training utilities: monitoring, precision switching, context and checkpoint setup."""
+import time
+import numpy as np
+
+from mindspore import context
+from mindspore import nn
+from mindspore import Tensor
+from mindspore.common import dtype as mstype
+from mindspore.train.model import ParallelMode
+from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, Callback
+from mindspore.communication.management import get_rank, init
+
+
+class Monitor(Callback):
+    """
+    Monitor loss and time.
+
+    Args:
+        lr_init (numpy.ndarray): learning rate for each training step.
+
+    Returns:
+        None
+
+    Examples:
+        >>> Monitor(lr_init=Tensor([0.05] * 100).asnumpy())
+    """
+
+    def __init__(self, lr_init=None):
+        super(Monitor, self).__init__()
+        self.lr_init = lr_init
+        self.lr_init_len = len(lr_init)
+
+    def epoch_begin(self, run_context):
+        self.losses = []
+        self.epoch_time = time.time()
+
+    def epoch_end(self, run_context):
+        cb_params = run_context.original_args()
+
+        epoch_mseconds = (time.time() - self.epoch_time) * 1000
+        per_step_mseconds = epoch_mseconds / cb_params.batch_num
+        print("epoch time: {:5.3f}, per step time: {:5.3f}, avg loss: {:5.3f}".format(epoch_mseconds,
+                                                                                      per_step_mseconds,
+                                                                                      np.mean(self.losses)))
+
+    def step_begin(self, run_context):
+        self.step_time = time.time()
+
+    def step_end(self, run_context):
+        cb_params = run_context.original_args()
+        step_mseconds = (time.time() - self.step_time) * 1000
+        step_loss = cb_params.net_outputs
+
+        if isinstance(step_loss, (tuple, list)) and isinstance(step_loss[0], Tensor):
+            step_loss = step_loss[0]
+        if isinstance(step_loss, Tensor):
+            step_loss = np.mean(step_loss.asnumpy())
+
+        self.losses.append(step_loss)
+        cur_step_in_epoch = (cb_params.cur_step_num - 1) % cb_params.batch_num
+
+        print("epoch: [{:3d}/{:3d}], step:[{:5d}/{:5d}], loss:[{:5.3f}/{:5.3f}], time:[{:5.3f}], lr:[{:5.3f}]".format(
+            cb_params.cur_epoch_num - 1, cb_params.epoch_num, cur_step_in_epoch, cb_params.batch_num, step_loss,
+            np.mean(self.losses), step_mseconds, self.lr_init[cb_params.cur_step_num - 1]))
+
+
+def switch_precision(net, data_type=mstype.float16):
+    """Cast the network to `data_type`, keeping Dense layers in float32 for accuracy."""
+    net.to_float(data_type)
+    for _, cell in net.cells_and_names():
+        if isinstance(cell, nn.Dense):
+            cell.to_float(mstype.float32)
+
+
+def set_context(args, config):
+    """Set the execution context for the chosen platform."""
+    if args.platform == "GPU":
+        context.set_context(mode=context.GRAPH_MODE,
+                            device_target="GPU",
+                            save_graphs=False)
+        context.set_auto_parallel_context(device_num=config.device_num,
+                                          parallel_mode=ParallelMode.DATA_PARALLEL, mirror_mean=True)
+    elif args.platform == "CPU":
+        context.set_context(mode=context.GRAPH_MODE,
+                            device_target="CPU",
+                            save_graphs=False)
+    elif args.platform == "Ascend":
+        context.set_context(mode=context.GRAPH_MODE,
+                            device_target="Ascend",
+                            device_id=config.device_id, save_graphs=False)
+        if config.run_distribute:
+            context.set_auto_parallel_context(device_num=config.rank_size, parallel_mode=ParallelMode.DATA_PARALLEL,
+                                              parameter_broadcast=True, mirror_mean=True)
+    else:
+        raise ValueError("Unsupported platform.")
+
+
+def config_ckpoint(args_opt, config, lr, step_size):
+    """Build the callback list: loss/time monitoring plus optional checkpointing."""
+    cb = None
+    if args_opt.platform in ("CPU", "GPU") or config.rank_id == 0:
+        cb = [Monitor(lr_init=lr.asnumpy())]
+
+        if args_opt.save_checkpoint:
+            config_ck = CheckpointConfig(save_checkpoint_steps=args_opt.save_checkpoint_epochs * step_size,
+                                         keep_checkpoint_max=args_opt.keep_checkpoint_max)
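+            # A checkpoint is written every `save_checkpoint_epochs` epochs;
+            # only the newest `keep_checkpoint_max` files are retained.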
+            ckpt_save_dir = args_opt.save_checkpoint_path
+
+            # on GPU each rank writes to its own subdirectory
+            if args_opt.platform == "GPU":
+                ckpt_save_dir += "ckpt_" + str(get_rank()) + "/"
+
+            ckpt_cb = ModelCheckpoint(prefix="mobilenetV2", directory=ckpt_save_dir, config=config_ck)
+            cb += [ckpt_cb]
+
+    return cb
+
+
+def init_device(config):
+    """Initialize collective communication for distributed Ascend or GPU runs."""
+    if (config.platform == "Ascend" and config.run_distribute) or (config.platform == "GPU"):
+        init(config.ccl)
diff --git a/tutorials/tutorial_code/mobilenetv2_finetune/train.py b/tutorials/tutorial_code/mobilenetv2_finetune/train.py
new file mode 100644
index 0000000000000000000000000000000000000000..632c621ff41949f04bb5612fae2d5d11477ddd2d
--- /dev/null
+++ b/tutorials/tutorial_code/mobilenetv2_finetune/train.py
@@ -0,0 +1,139 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""MobileNetV2 fine-tune training script."""
+import random
+import numpy as np
+
+from mindspore import Tensor
+from mindspore import nn
+from mindspore.nn.optim.momentum import Momentum
+from mindspore.nn.loss import SoftmaxCrossEntropyWithLogits
+from mindspore.nn.loss.loss import _Loss
+from mindspore.ops import operations as P
+from mindspore.ops import functional as F
+from mindspore.common import dtype as mstype
+from mindspore.train.model import Model
+from mindspore.train.loss_scale_manager import FixedLossScaleManager
+import mindspore.dataset.engine as de
+
+from src.dataset import create_dataset
+from src.lr_generator import get_lr
+from src.config import set_config
+from src.mobilenetV2 import mobilenet_v2
+from src.args import train_parse_args
+from src.models import model_fine_tune
+from src.utils import switch_precision, set_context, config_ckpoint, init_device
+
+random.seed(1)
+np.random.seed(1)
+de.config.set_seed(1)
+
+
+class CrossEntropyWithLabelSmooth(_Loss):
+    """
+    Cross entropy loss with label smoothing.
+
+    Args:
+        smooth_factor (float): smooth factor, default=0.
+        num_classes (int): number of classes.
+
+    Returns:
+        Tensor, the computed loss.
+
+    Examples:
+        >>> CrossEntropyWithLabelSmooth(smooth_factor=0., num_classes=1000)
+    """
+
+    def __init__(self, smooth_factor=0., num_classes=1000):
+        super(CrossEntropyWithLabelSmooth, self).__init__()
+        self.onehot = P.OneHot()
+        self.on_value = Tensor(1.0 - smooth_factor, mstype.float32)
+        self.off_value = Tensor(1.0 * smooth_factor /
+                                (num_classes - 1), mstype.float32)
+        self.ce = nn.SoftmaxCrossEntropyWithLogits()
+        self.mean = P.ReduceMean(False)
+        self.cast = P.Cast()
+
+    def construct(self, logit, label):
+        one_hot_label = self.onehot(self.cast(label, mstype.int32), F.shape(logit)[1],
+                                    self.on_value, self.off_value)
+        out_loss = self.ce(logit, one_hot_label)
+        out_loss = self.mean(out_loss, 0)
+        return out_loss
+
+
+if __name__ == '__main__':
+    args_opt = train_parse_args()
+    config = set_config(args_opt)
+    set_context(args_opt, config)
+
+    init_device(config)
+
+    # define network; MobileNetV2 expects the run device as `device_target`
+    net = mobilenet_v2(num_classes=args_opt.num_classes, device_target=args_opt.platform)
+
+    print("train args: ", args_opt, "\ncfg: ", config)
+    if args_opt.platform == "Ascend":
+        print("parallel args: rank_id {}, device_id {}, rank_size {}".format(
+            config.rank_id, config.device_id, config.rank_size))
+        switch_precision(net, mstype.float16)
+
+    # load the pretrained backbone and freeze everything but the head
+    model_fine_tune(net, args_opt.pretrain_checkpoint_path)
+
+    # define loss
+    if args_opt.label_smooth > 0:
+        loss = CrossEntropyWithLabelSmooth(
+            smooth_factor=args_opt.label_smooth, num_classes=args_opt.num_classes)
+    else:
+        loss = SoftmaxCrossEntropyWithLogits(
+            is_grad=False, sparse=True, reduction='mean')
+
+    # define dataset
+    epoch_size = args_opt.epoch_size
+    dataset = create_dataset(args=args_opt,
+                             do_train=True,
+                             repeat_num=1)
+
+    step_size = dataset.get_dataset_size()
+
+    # define optimizer; only parameters with requires_grad=True are updated
+    loss_scale = FixedLossScaleManager(
+        args_opt.loss_scale, drop_overflow_update=False)
+    lr = Tensor(get_lr(global_step=0,
+                       lr_init=0,
+                       lr_end=0,
+                       lr_max=args_opt.lr,
+                       warmup_epochs=args_opt.warmup_epochs,
+                       total_epochs=epoch_size,
+                       steps_per_epoch=step_size))
+    opt = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), lr, args_opt.momentum,
+                   args_opt.weight_decay, args_opt.loss_scale)
+    # define model
+    model = Model(net, loss_fn=loss, optimizer=opt, loss_scale_manager=loss_scale)
+
+    cb = config_ckpoint(args_opt, config, lr, step_size)
+
+    # begin train
+    model.train(epoch_size, dataset, callbacks=cb)
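+
+# A minimal launch sketch. Training is started through launch.py (see the
+# tutorial README); the flag values below are hypothetical and the full
+# argument list is defined in src/args.py:
+#
+#   python launch.py --platform GPU --dataset_path /path/to/ImageFolder \
+#       --pretrain_checkpoint_path ./pretrain_checkpoint/mobilenetv2.ckpt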