From 1c1f76989df58d16c4802ea9afbb4f77529d531d Mon Sep 17 00:00:00 2001 From: majun121 <867479212@qq.com> Date: Tue, 17 May 2022 12:15:34 +0000 Subject: [PATCH] update DeepCTR_Series_for_TensorFlow/examples/run_classification_criteo.py. --- .../examples/run_classification_criteo.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/TensorFlow/built-in/recommendation/DeepCTR_Series_for_TensorFlow/examples/run_classification_criteo.py b/TensorFlow/built-in/recommendation/DeepCTR_Series_for_TensorFlow/examples/run_classification_criteo.py index 877c0a139..261089a40 100644 --- a/TensorFlow/built-in/recommendation/DeepCTR_Series_for_TensorFlow/examples/run_classification_criteo.py +++ b/TensorFlow/built-in/recommendation/DeepCTR_Series_for_TensorFlow/examples/run_classification_criteo.py @@ -38,6 +38,8 @@ from deepctr.models import DeepFM from deepctr.feature_column import SparseFeat, DenseFeat, get_feature_names import argparse +from tensorflow.keras.callbacks import ModelCheckpoint + if __name__ == "__main__": parser = argparse.ArgumentParser() @@ -106,8 +108,13 @@ if __name__ == "__main__": model.compile("adam", "binary_crossentropy", metrics=['binary_crossentropy'], ) + ######## ckpt ######### + checkpoint = ModelCheckpoint(filepath="../test/output/ckpt",save_best_only=True,verbose=1,period=1) history = model.fit(train_model_input, train[target].values, - batch_size=128, epochs=10, verbose=1, validation_split=0.2, ) + batch_size=128, epochs=10, verbose=1, validation_split=0.2, callbacks=[checkpoint]) + ######## ckpt ######### + #history = model.fit(train_model_input, train[target].values, + # batch_size=128, epochs=10, verbose=1, validation_split=0.2, ) pred_ans = model.predict(test_model_input, batch_size=8) print("test LogLoss", round(log_loss(test[target].values, pred_ans), 4)) print("test AUC", round(roc_auc_score(test[target].values, pred_ans), 4)) -- Gitee