From 10b5b1f7d7edf642d381c10dbe9e40982fb02986 Mon Sep 17 00:00:00 2001 From: wukesong Date: Wed, 3 Jun 2020 13:00:18 +0800 Subject: [PATCH] modify lenet dataset_sink_mode=True --- tutorials/source_en/quick_start/quick_start.md | 2 +- tutorials/source_zh_cn/quick_start/quick_start.md | 2 +- tutorials/tutorial_code/lenet.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/tutorials/source_en/quick_start/quick_start.md b/tutorials/source_en/quick_start/quick_start.md index 0ad4925190..420f1e5695 100644 --- a/tutorials/source_en/quick_start/quick_start.md +++ b/tutorials/source_en/quick_start/quick_start.md @@ -343,7 +343,7 @@ def train_net(args, model, epoch_size, mnist_path, repeat_size, ckpoint_cb): print("============== Starting Training ==============") #load training dataset ds_train = create_dataset(os.path.join(mnist_path, "train"), 32, repeat_size) - model.train(epoch_size, ds_train, callbacks=[ckpoint_cb, LossMonitor()], dataset_sink_mode=False) # train + model.train(epoch_size, ds_train, callbacks=[ckpoint_cb, LossMonitor()], dataset_sink_mode=True) # train ... if __name__ == "__main__": diff --git a/tutorials/source_zh_cn/quick_start/quick_start.md b/tutorials/source_zh_cn/quick_start/quick_start.md index f4ffebd70e..58d4f35b62 100644 --- a/tutorials/source_zh_cn/quick_start/quick_start.md +++ b/tutorials/source_zh_cn/quick_start/quick_start.md @@ -344,7 +344,7 @@ def train_net(args, model, epoch_size, mnist_path, repeat_size, ckpoint_cb): print("============== Starting Training ==============") #load training dataset ds_train = create_dataset(os.path.join(mnist_path, "train"), 32, repeat_size) - model.train(epoch_size, ds_train, callbacks=[ckpoint_cb, LossMonitor()], dataset_sink_mode=False) + model.train(epoch_size, ds_train, callbacks=[ckpoint_cb, LossMonitor()], dataset_sink_mode=True) ... if __name__ == "__main__": diff --git a/tutorials/tutorial_code/lenet.py b/tutorials/tutorial_code/lenet.py index 84d0a9c54a..e2d5d03d38 100644 --- a/tutorials/tutorial_code/lenet.py +++ b/tutorials/tutorial_code/lenet.py @@ -174,7 +174,7 @@ def train_net(args, model, epoch_size, mnist_path, repeat_size, ckpoint_cb): print("============== Starting Training ==============") # load training dataset ds_train = create_dataset(os.path.join(mnist_path, "train"), 32, repeat_size) - model.train(epoch_size, ds_train, callbacks=[ckpoint_cb, LossMonitor()], dataset_sink_mode=False) + model.train(epoch_size, ds_train, callbacks=[ckpoint_cb, LossMonitor()], dataset_sink_mode=True) def test_net(args, network, model, mnist_path): -- Gitee