From ab5c5ae6ed29d255035baa96c2edb19f42d92bf6 Mon Sep 17 00:00:00 2001 From: "wangnan39@huawei.com" Date: Tue, 12 May 2020 16:38:21 +0800 Subject: [PATCH] fix import error in docs that use on the cloud --- .../advanced_use/use_on_the_cloud.md | 25 +++++++++++-------- 1 file changed, 15 insertions(+), 10 deletions(-) diff --git a/tutorials/source_zh_cn/advanced_use/use_on_the_cloud.md b/tutorials/source_zh_cn/advanced_use/use_on_the_cloud.md index 8acca15e18..172505d597 100644 --- a/tutorials/source_zh_cn/advanced_use/use_on_the_cloud.md +++ b/tutorials/source_zh_cn/advanced_use/use_on_the_cloud.md @@ -89,7 +89,7 @@ ModelArts使用对象存储服务(Object Storage Service,简称OBS)进行 1. 在ModelArts运行的脚本必须配置`data_url`和`train_url`,分别对应数据存储路径(OBS路径)和训练输出路径(OBS路径)。 ``` python - import parser + import argparse parser = argparse.ArgumentParser(description='ResNet-50 train.') parser.add_argument('--data_url', required=True, default=None, help='Location of data.') @@ -160,6 +160,8 @@ MindSpore暂时没有提供直接访问OBS数据的接口,需要通过MoXing ```python import os + from mindspore import context + from mindspore.train.model import ParallelMode device_num = int(os.getenv('RANK_SIZE')) if device_num > 1: @@ -176,6 +178,7 @@ MindSpore暂时没有提供直接访问OBS数据的接口,需要通过MoXing ``` python import os +import argparse from mindspore import context from mindspore.train.model import ParallelMode import mindspore.dataset.engine as de @@ -194,8 +197,8 @@ def create_dataset(dataset_path): def resnet50_train(args_opt): if device_num > 1: context.set_auto_parallel_context(device_num=device_num, - parallel_mode=ParallelMode.DATA_PARALLEL, - mirror_mean=True) + parallel_mode=ParallelMode.DATA_PARALLEL, + mirror_mean=True) train_dataset = create_dataset(local_data_path) if __name__ == '__main__': @@ -212,10 +215,12 @@ if __name__ == '__main__': ``` python import os +import argparse from mindspore import context from mindspore.train.model import ParallelMode import mindspore.dataset.engine as de +# adapt to cloud: used for downloading data import moxing as mox device_id = int(os.getenv('DEVICE_ID')) @@ -230,19 +235,17 @@ def create_dataset(dataset_path): return ds def resnet50_train(args_opt): - epoch_size = args_opt.epoch_size - # define local data path + # adapt to cloud: define local data path local_data_path = '/cache/data' - context.set_context(mode=context.GRAPH_MODE) if device_num > 1: context.set_auto_parallel_context(device_num=device_num, - parallel_mode=ParallelMode.DATA_PARALLEL, - mirror_mean=True) - # define distributed local data path + parallel_mode=ParallelMode.DATA_PARALLEL, + mirror_mean=True) + # adapt to cloud: define distributed local data path local_data_path = os.path.join(local_data_path, str(device_id)) - # data download + # adapt to cloud: download data from obs to local location print('Download data.') mox.file.copy_parallel(src_url=args_opt.data_url, dst_url=local_data_path) @@ -250,7 +253,9 @@ def resnet50_train(args_opt): if __name__ == '__main__': parser = argparse.ArgumentParser(description='ResNet-50 train.') + # adapt to cloud: get obs data path parser.add_argument('--data_url', required=True, default=None, help='Location of data.') + # adapt to cloud: get obs output path parser.add_argument('--train_url', required=True, default=None, help='Location of training outputs.') parser.add_argument('--epoch_size', type=int, default=90, help='Train epoch size.') args_opt, unknown = parser.parse_known_args() -- Gitee