From 2bee538f8d8e061f740f113572e3a01b0b8ad01b Mon Sep 17 00:00:00 2001 From: limingxing517vim Date: Mon, 22 May 2023 09:50:14 +0800 Subject: [PATCH] update --- .../DIN_ID0190_for_TensorFlow/deepctr/layers/core.py | 4 ++-- .../DIN_ID0190_for_TensorFlow/examples/din_demo.py | 12 ++++++------ .../DIN_ID0190_for_TensorFlow/test/train_full_8p.sh | 2 +- .../test/train_performance_1p.sh | 2 +- .../test/train_performance_8p.sh | 2 +- 5 files changed, 11 insertions(+), 11 deletions(-) diff --git a/TensorFlow/built-in/recommendation/DIN_ID0190_for_TensorFlow/deepctr/layers/core.py b/TensorFlow/built-in/recommendation/DIN_ID0190_for_TensorFlow/deepctr/layers/core.py index 139480514..5ded7e3d3 100644 --- a/TensorFlow/built-in/recommendation/DIN_ID0190_for_TensorFlow/deepctr/layers/core.py +++ b/TensorFlow/built-in/recommendation/DIN_ID0190_for_TensorFlow/deepctr/layers/core.py @@ -97,8 +97,8 @@ class LocalActivationUnit(Layer): query, keys = inputs keys_len = keys.get_shape()[1] - #queries = K.repeat_elements(query, keys_len, 1) - queries = tf.tile(query, [1, keys_len, 1]) + queries = K.repeat_elements(query, keys_len, 1) + #queries = tf.tile(query, [1, keys_len, 1]) att_input = tf.concat( [queries, keys, queries - keys, queries * keys], axis=-1) diff --git a/TensorFlow/built-in/recommendation/DIN_ID0190_for_TensorFlow/examples/din_demo.py b/TensorFlow/built-in/recommendation/DIN_ID0190_for_TensorFlow/examples/din_demo.py index c84ba337d..4af71c2bb 100644 --- a/TensorFlow/built-in/recommendation/DIN_ID0190_for_TensorFlow/examples/din_demo.py +++ b/TensorFlow/built-in/recommendation/DIN_ID0190_for_TensorFlow/examples/din_demo.py @@ -30,12 +30,12 @@ def split_tfrecord(tfrecord_path): def input_fn(filenames, is_train, batch_size=1024): def _parse_function(example_proto): feature_description = { - "movieId": tf.io.FixedLenFeature([1], np.int64), - "cateId": tf.io.FixedLenFeature([1], np.int64), - "hist_movieId": tf.io.FixedLenFeature([200], np.int64), - "hist_cateId": tf.io.FixedLenFeature([200], np.int64), - "seq_length": tf.io.FixedLenFeature([1], np.int64), - "label": tf.io.FixedLenFeature([1], np.float32) + "movieId": tf.io.FixedLenFeature([1], tf.int64), + "cateId": tf.io.FixedLenFeature([1], tf.int64), + "hist_movieId": tf.io.FixedLenFeature([200], tf.int64), + "hist_cateId": tf.io.FixedLenFeature([200], tf.int64), + "seq_length": tf.io.FixedLenFeature([1], tf.int64), + "label": tf.io.FixedLenFeature([1], tf.float32) } features = tf.io.parse_example(example_proto, feature_description) labels = features.pop("label") diff --git a/TensorFlow/built-in/recommendation/DIN_ID0190_for_TensorFlow/test/train_full_8p.sh b/TensorFlow/built-in/recommendation/DIN_ID0190_for_TensorFlow/test/train_full_8p.sh index 5c63b6d7c..af22b7722 100644 --- a/TensorFlow/built-in/recommendation/DIN_ID0190_for_TensorFlow/test/train_full_8p.sh +++ b/TensorFlow/built-in/recommendation/DIN_ID0190_for_TensorFlow/test/train_full_8p.sh @@ -5,7 +5,7 @@ cur_path=`pwd` export RANK_SIZE=8 export RANK_TABLE_FILE=$cur_path/rank_table_8p.json export JOB_ID=10087 -export OP_NO_REUSE_MEM=StridedSliceD +#export OP_NO_REUSE_MEM=StridedSliceD #export ASCEND_DEVICE_ID= diff --git a/TensorFlow/built-in/recommendation/DIN_ID0190_for_TensorFlow/test/train_performance_1p.sh b/TensorFlow/built-in/recommendation/DIN_ID0190_for_TensorFlow/test/train_performance_1p.sh index 57afc90d5..88f727c25 100644 --- a/TensorFlow/built-in/recommendation/DIN_ID0190_for_TensorFlow/test/train_performance_1p.sh +++ b/TensorFlow/built-in/recommendation/DIN_ID0190_for_TensorFlow/test/train_performance_1p.sh @@ -6,7 +6,7 @@ export RANK_SIZE=1 export JOB_ID=10087 #export ASCEND_DEVICE_ID= -export OP_NO_REUSE_MEM=StridedSliceD +#export OP_NO_REUSE_MEM=StridedSliceD # 数据集路径,保持为空,不需要修改 data_path="" diff --git a/TensorFlow/built-in/recommendation/DIN_ID0190_for_TensorFlow/test/train_performance_8p.sh b/TensorFlow/built-in/recommendation/DIN_ID0190_for_TensorFlow/test/train_performance_8p.sh index 2f8698292..9076c5523 100644 --- a/TensorFlow/built-in/recommendation/DIN_ID0190_for_TensorFlow/test/train_performance_8p.sh +++ b/TensorFlow/built-in/recommendation/DIN_ID0190_for_TensorFlow/test/train_performance_8p.sh @@ -5,7 +5,7 @@ cur_path=`pwd` export RANK_SIZE=8 export RANK_TABLE_FILE=$cur_path/rank_table_8p.json export JOB_ID=10087 -export OP_NO_REUSE_MEM=StridedSliceD +#export OP_NO_REUSE_MEM=StridedSliceD #export ASCEND_DEVICE_ID= -- Gitee