diff --git a/tutorials/source_en/quick_start/quick_start.md b/tutorials/source_en/quick_start/quick_start.md
index 0ad4925190e0871799c9addc858df112243c53b0..801aa59c62f96c64fc44091af3e335b4d04a390e 100644
--- a/tutorials/source_en/quick_start/quick_start.md
+++ b/tutorials/source_en/quick_start/quick_start.md
@@ -100,6 +100,7 @@ if __name__ == "__main__":
                         help='device where the code will be implemented (default: CPU)')
     args = parser.parse_args()
     context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target)
+    dataset_sink_mode = not args.device_target == "CPU"
     ...
 ```
 
@@ -338,12 +339,12 @@
 from mindspore.train.callback import LossMonitor
 from mindspore.train import Model
 ...
-def train_net(args, model, epoch_size, mnist_path, repeat_size, ckpoint_cb):
+def train_net(args, model, epoch_size, mnist_path, repeat_size, ckpoint_cb, sink_mode):
     """define the training method"""
     print("============== Starting Training ==============")
     #load training dataset
     ds_train = create_dataset(os.path.join(mnist_path, "train"), 32, repeat_size)
-    model.train(epoch_size, ds_train, callbacks=[ckpoint_cb, LossMonitor()], dataset_sink_mode=False) # train
+    model.train(epoch_size, ds_train, callbacks=[ckpoint_cb, LossMonitor()], dataset_sink_mode=sink_mode) # train
 ...
 
 if __name__ == "__main__":
@@ -353,7 +354,7 @@ if __name__ == "__main__":
     mnist_path = "./MNIST_Data"
     repeat_size = epoch_size
     model = Model(network, net_loss, net_opt, metrics={"Accuracy": Accuracy()})
-    train_net(args, model, epoch_size, mnist_path, repeat_size, ckpoint_cb)
+    train_net(args, model, epoch_size, mnist_path, repeat_size, ckpoint_cb, dataset_sink_mode)
     ...
 ```
 In the preceding information:
diff --git a/tutorials/source_zh_cn/quick_start/quick_start.md b/tutorials/source_zh_cn/quick_start/quick_start.md
index f4ffebd70e9d5c714bea5fe571313d09e362a36b..0e83e6752d95ea7dba779fa7251d7ffc8d37e5c2 100644
--- a/tutorials/source_zh_cn/quick_start/quick_start.md
+++ b/tutorials/source_zh_cn/quick_start/quick_start.md
@@ -102,6 +102,7 @@ if __name__ == "__main__":
                         help='device where the code will be implemented (default: CPU)')
     args = parser.parse_args()
     context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target)
+    dataset_sink_mode = not args.device_target == "CPU"
     ...
 ```
 
@@ -339,12 +340,12 @@
 from mindspore.train.callback import LossMonitor
 from mindspore.train import Model
 ...
-def train_net(args, model, epoch_size, mnist_path, repeat_size, ckpoint_cb):
+def train_net(args, model, epoch_size, mnist_path, repeat_size, ckpoint_cb, sink_mode):
     """define the training method"""
     print("============== Starting Training ==============")
     #load training dataset
     ds_train = create_dataset(os.path.join(mnist_path, "train"), 32, repeat_size)
-    model.train(epoch_size, ds_train, callbacks=[ckpoint_cb, LossMonitor()], dataset_sink_mode=False)
+    model.train(epoch_size, ds_train, callbacks=[ckpoint_cb, LossMonitor()], dataset_sink_mode=sink_mode)
 ...
 
 if __name__ == "__main__":
@@ -354,7 +355,7 @@ if __name__ == "__main__":
     mnist_path = "./MNIST_Data"
     repeat_size = epoch_size
     model = Model(network, net_loss, net_opt, metrics={"Accuracy": Accuracy()})
-    train_net(args, model, epoch_size, mnist_path, repeat_size, ckpoint_cb)
+    train_net(args, model, epoch_size, mnist_path, repeat_size, ckpoint_cb, dataset_sink_mode)
     ...
 ```
 其中,
diff --git a/tutorials/tutorial_code/lenet.py b/tutorials/tutorial_code/lenet.py
index 84d0a9c54afd3d6cb4f5904b70709644b4a8753a..441f423360c179140f23767970e865e179a10a6c 100644
--- a/tutorials/tutorial_code/lenet.py
+++ b/tutorials/tutorial_code/lenet.py
@@ -169,12 +169,12 @@ class LeNet5(nn.Cell):
         return x
 
 
-def train_net(args, model, epoch_size, mnist_path, repeat_size, ckpoint_cb):
+def train_net(args, model, epoch_size, mnist_path, repeat_size, ckpoint_cb, sink_mode):
     """Define the training method."""
     print("============== Starting Training ==============")
     # load training dataset
     ds_train = create_dataset(os.path.join(mnist_path, "train"), 32, repeat_size)
-    model.train(epoch_size, ds_train, callbacks=[ckpoint_cb, LossMonitor()], dataset_sink_mode=False)
+    model.train(epoch_size, ds_train, callbacks=[ckpoint_cb, LossMonitor()], dataset_sink_mode=sink_mode)
 
 
 def test_net(args, network, model, mnist_path):
@@ -196,6 +196,7 @@ if __name__ == "__main__":
                         help='device where the code will be implemented (default: CPU)')
     args = parser.parse_args()
     context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target)
+    dataset_sink_mode = not args.device_target == "CPU"
     # download mnist dataset
     download_dataset()
     # learning rate setting
@@ -216,5 +217,5 @@ if __name__ == "__main__":
 
     # group layers into an object with training and evaluation features
     model = Model(network, net_loss, net_opt, metrics={"Accuracy": Accuracy()})
-    train_net(args, model, epoch_size, mnist_path, repeat_size, ckpoint_cb)
+    train_net(args, model, epoch_size, mnist_path, repeat_size, ckpoint_cb, dataset_sink_mode)
     test_net(args, network, model, mnist_path)
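Taken together, the hunks make a single change: instead of hard-coding `dataset_sink_mode=False`, the sink mode is derived from the parsed `--device_target` argument and threaded through `train_net` into `model.train`. The sketch below is a minimal, self-contained illustration of that derivation only; the argparse setup and identifiers mirror the diff, the `print` line is illustrative, the import form of `context` follows common MindSpore usage and may differ by version, and the surrounding network/Model/checkpoint setup from lenet.py is intentionally omitted.

```python
# Sketch of the sink-mode derivation introduced above; the rest of lenet.py
# (network, Model, train_net) is unchanged apart from forwarding this flag.
import argparse

from mindspore import context

parser = argparse.ArgumentParser(description='MindSpore LeNet Example')
parser.add_argument('--device_target', type=str, default="CPU",
                    help='device where the code will be implemented (default: CPU)')
args = parser.parse_args()

context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target)

# Equivalent to `args.device_target != "CPU"`: data sinking is enabled only
# for non-CPU targets such as Ascend or GPU.
dataset_sink_mode = not args.device_target == "CPU"

print("dataset_sink_mode =", dataset_sink_mode)
# lenet.py then forwards this flag:
#   train_net(args, model, epoch_size, mnist_path, repeat_size, ckpoint_cb, dataset_sink_mode)
# and inside train_net:
#   model.train(epoch_size, ds_train, callbacks=[ckpoint_cb, LossMonitor()], dataset_sink_mode=sink_mode)
```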