diff --git a/TensorFlow/built-in/recommendation/DIN_ID0190_for_TensorFlow/deepctr/layers/core.py b/TensorFlow/built-in/recommendation/DIN_ID0190_for_TensorFlow/deepctr/layers/core.py index 1394805149be536c6885fed1b9f23f4530c5d34a..5ded7e3d3d1c3903cb340763b2c18dab0adefc4e 100644 --- a/TensorFlow/built-in/recommendation/DIN_ID0190_for_TensorFlow/deepctr/layers/core.py +++ b/TensorFlow/built-in/recommendation/DIN_ID0190_for_TensorFlow/deepctr/layers/core.py @@ -97,8 +97,8 @@ class LocalActivationUnit(Layer): query, keys = inputs keys_len = keys.get_shape()[1] - #queries = K.repeat_elements(query, keys_len, 1) - queries = tf.tile(query, [1, keys_len, 1]) + queries = K.repeat_elements(query, keys_len, 1) + #queries = tf.tile(query, [1, keys_len, 1]) att_input = tf.concat( [queries, keys, queries - keys, queries * keys], axis=-1) diff --git a/TensorFlow/built-in/recommendation/DIN_ID0190_for_TensorFlow/examples/din_demo.py b/TensorFlow/built-in/recommendation/DIN_ID0190_for_TensorFlow/examples/din_demo.py index c84ba337dc262ce4e20404b347a5535bc8738777..4af71c2bbc73e189d28937a49b3f3c3e8646095b 100644 --- a/TensorFlow/built-in/recommendation/DIN_ID0190_for_TensorFlow/examples/din_demo.py +++ b/TensorFlow/built-in/recommendation/DIN_ID0190_for_TensorFlow/examples/din_demo.py @@ -30,12 +30,12 @@ def split_tfrecord(tfrecord_path): def input_fn(filenames, is_train, batch_size=1024): def _parse_function(example_proto): feature_description = { - "movieId": tf.io.FixedLenFeature([1], np.int64), - "cateId": tf.io.FixedLenFeature([1], np.int64), - "hist_movieId": tf.io.FixedLenFeature([200], np.int64), - "hist_cateId": tf.io.FixedLenFeature([200], np.int64), - "seq_length": tf.io.FixedLenFeature([1], np.int64), - "label": tf.io.FixedLenFeature([1], np.float32) + "movieId": tf.io.FixedLenFeature([1], tf.int64), + "cateId": tf.io.FixedLenFeature([1], tf.int64), + "hist_movieId": tf.io.FixedLenFeature([200], tf.int64), + "hist_cateId": tf.io.FixedLenFeature([200], tf.int64), + "seq_length": tf.io.FixedLenFeature([1], tf.int64), + "label": tf.io.FixedLenFeature([1], tf.float32) } features = tf.io.parse_example(example_proto, feature_description) labels = features.pop("label") diff --git a/TensorFlow/built-in/recommendation/DIN_ID0190_for_TensorFlow/test/train_full_8p.sh b/TensorFlow/built-in/recommendation/DIN_ID0190_for_TensorFlow/test/train_full_8p.sh index 5c63b6d7c9ba1e01954b84fb226d358f67e162ae..af22b772227fb50ae640d355ebc3f75c7f249add 100644 --- a/TensorFlow/built-in/recommendation/DIN_ID0190_for_TensorFlow/test/train_full_8p.sh +++ b/TensorFlow/built-in/recommendation/DIN_ID0190_for_TensorFlow/test/train_full_8p.sh @@ -5,7 +5,7 @@ cur_path=`pwd` export RANK_SIZE=8 export RANK_TABLE_FILE=$cur_path/rank_table_8p.json export JOB_ID=10087 -export OP_NO_REUSE_MEM=StridedSliceD +#export OP_NO_REUSE_MEM=StridedSliceD #export ASCEND_DEVICE_ID= diff --git a/TensorFlow/built-in/recommendation/DIN_ID0190_for_TensorFlow/test/train_performance_1p.sh b/TensorFlow/built-in/recommendation/DIN_ID0190_for_TensorFlow/test/train_performance_1p.sh index 57afc90d52fb87da4920f30d501cf3f389400034..88f727c254b61697c1a5ab46eb9af5301ffa7329 100644 --- a/TensorFlow/built-in/recommendation/DIN_ID0190_for_TensorFlow/test/train_performance_1p.sh +++ b/TensorFlow/built-in/recommendation/DIN_ID0190_for_TensorFlow/test/train_performance_1p.sh @@ -6,7 +6,7 @@ export RANK_SIZE=1 export JOB_ID=10087 #export ASCEND_DEVICE_ID= -export OP_NO_REUSE_MEM=StridedSliceD +#export OP_NO_REUSE_MEM=StridedSliceD # 数据集路径,保持为空,不需要修改 data_path="" diff --git a/TensorFlow/built-in/recommendation/DIN_ID0190_for_TensorFlow/test/train_performance_8p.sh b/TensorFlow/built-in/recommendation/DIN_ID0190_for_TensorFlow/test/train_performance_8p.sh index 2f8698292047483cc9f2b6bb401198773dd41b31..9076c5523d5731d5ff86896a51b23ec3283219ba 100644 --- a/TensorFlow/built-in/recommendation/DIN_ID0190_for_TensorFlow/test/train_performance_8p.sh +++ b/TensorFlow/built-in/recommendation/DIN_ID0190_for_TensorFlow/test/train_performance_8p.sh @@ -5,7 +5,7 @@ cur_path=`pwd` export RANK_SIZE=8 export RANK_TABLE_FILE=$cur_path/rank_table_8p.json export JOB_ID=10087 -export OP_NO_REUSE_MEM=StridedSliceD +#export OP_NO_REUSE_MEM=StridedSliceD #export ASCEND_DEVICE_ID=