diff --git a/tutorials/source_en/quick_start/quick_start.md b/tutorials/source_en/quick_start/quick_start.md index 0ad4925190e0871799c9addc858df112243c53b0..420f1e56955c3bced564dc6d40afec1d3bb8f3c5 100644 --- a/tutorials/source_en/quick_start/quick_start.md +++ b/tutorials/source_en/quick_start/quick_start.md @@ -343,7 +343,7 @@ def train_net(args, model, epoch_size, mnist_path, repeat_size, ckpoint_cb): print("============== Starting Training ==============") #load training dataset ds_train = create_dataset(os.path.join(mnist_path, "train"), 32, repeat_size) - model.train(epoch_size, ds_train, callbacks=[ckpoint_cb, LossMonitor()], dataset_sink_mode=False) # train + model.train(epoch_size, ds_train, callbacks=[ckpoint_cb, LossMonitor()], dataset_sink_mode=True) # train ... if __name__ == "__main__": diff --git a/tutorials/source_zh_cn/quick_start/quick_start.md b/tutorials/source_zh_cn/quick_start/quick_start.md index f4ffebd70e9d5c714bea5fe571313d09e362a36b..58d4f35b62902e300c556054198c34cdb17bc834 100644 --- a/tutorials/source_zh_cn/quick_start/quick_start.md +++ b/tutorials/source_zh_cn/quick_start/quick_start.md @@ -344,7 +344,7 @@ def train_net(args, model, epoch_size, mnist_path, repeat_size, ckpoint_cb): print("============== Starting Training ==============") #load training dataset ds_train = create_dataset(os.path.join(mnist_path, "train"), 32, repeat_size) - model.train(epoch_size, ds_train, callbacks=[ckpoint_cb, LossMonitor()], dataset_sink_mode=False) + model.train(epoch_size, ds_train, callbacks=[ckpoint_cb, LossMonitor()], dataset_sink_mode=True) ... if __name__ == "__main__": diff --git a/tutorials/tutorial_code/lenet.py b/tutorials/tutorial_code/lenet.py index 84d0a9c54afd3d6cb4f5904b70709644b4a8753a..e2d5d03d38c5f08450efe0e60d8a945088692ef6 100644 --- a/tutorials/tutorial_code/lenet.py +++ b/tutorials/tutorial_code/lenet.py @@ -174,7 +174,7 @@ def train_net(args, model, epoch_size, mnist_path, repeat_size, ckpoint_cb): print("============== Starting Training ==============") # load training dataset ds_train = create_dataset(os.path.join(mnist_path, "train"), 32, repeat_size) - model.train(epoch_size, ds_train, callbacks=[ckpoint_cb, LossMonitor()], dataset_sink_mode=False) + model.train(epoch_size, ds_train, callbacks=[ckpoint_cb, LossMonitor()], dataset_sink_mode=True) def test_net(args, network, model, mnist_path):