diff --git a/recommendation/ctr/dlrm/paddlepaddle/README.md b/recommendation/ctr/dlrm/paddlepaddle/README.md new file mode 100644 index 0000000000000000000000000000000000000000..3758fc2b5ba1f625a94c23b87e431ef73c061190 --- /dev/null +++ b/recommendation/ctr/dlrm/paddlepaddle/README.md @@ -0,0 +1,46 @@ +# DLRM + +## Model description +With the advent of deep learning, neural network-based recommendation models have emerged as an important tool for tackling personalization and recommendation tasks. These networks differ significantly from other deep learning networks due to their need to handle categorical features and are not well studied or understood. In this paper, we develop a state-of-the-art deep learning recommendation model (DLRM) and provide its implementation in both PyTorch and Caffe2 frameworks. In addition, we design a specialized parallelization scheme utilizing model parallelism on the embedding tables to mitigate memory constraints while exploiting data parallelism to scale-out compute from the fully-connected layers. We compare DLRM against existing recommendation models and characterize its performance on the Big Basin AI platform, demonstrating its usefulness as a benchmark for future algorithmic experimentation and system co-design. + +## Step 1: Installing + +```bash +git clone -b master --recursive https://github.com/PaddlePaddle/PaddleRec.git +cd PaddleRec +git checkout eb869a15b7d858f9f3788d9b25af4f61a022f9c4 +pip3 install -r requirements.txt + +cp ../net.py ../config_bigdata_multi_cards.yaml ./models/rank/dlrm +cp ../trainer.py tools/trainer.py +``` + +Replace net.py and trainer.py, add config_bigdata_multi_cards.yaml with use_fleet to support training in multiple cards. + + +## Step 2: Download data + +```bash +cd datasets/criteo +sh run.sh +``` + +## Step 3: Run DLRM + +### one gpu + +```bash +# Make sure your dataset path is the same as above +cd ../../models/rank/dlrm +python3 -u ../../../tools/trainer.py -m config_bigdata.yaml + +``` + +### 8 gpus + +```bash +# Make sure your dataset path is the same as above +cd ../../models/rank/dlrm +export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 +python3 -u -m paddle.distributed.launch --gpus 0,1,2,3,4,5,6,7 ../../../tools/trainer.py -m config_bigdata_multi_cards.yaml +``` \ No newline at end of file diff --git a/recommendation/ctr/dlrm/paddlepaddle/config_bigdata_multi_cards.yaml b/recommendation/ctr/dlrm/paddlepaddle/config_bigdata_multi_cards.yaml new file mode 100644 index 0000000000000000000000000000000000000000..66ab297a403d74550c950d87461a112083ed83cb --- /dev/null +++ b/recommendation/ctr/dlrm/paddlepaddle/config_bigdata_multi_cards.yaml @@ -0,0 +1,58 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# workspace +#workspace: "models/rank/dlrm" + + +runner: + train_data_dir: "../../../datasets/criteo/slot_train_data_full" + train_reader_path: "criteo_reader" # importlib format + use_gpu: True + use_xpu: False # Enable this option only if you have an xpu device + use_auc: True + train_batch_size: 2048 + epochs: 1 + print_interval: 10 + model_save_path: "output_model_dlrm_all" + infer_batch_size: 2048 + infer_reader_path: "criteo_reader" # importlib format + test_data_dir: "../../../datasets/criteo/slot_test_data_full" + infer_load_path: "output_model_dlrm_all" + infer_start_epoch: 0 + infer_end_epoch: 1 + + # distribute_config + sync_mode: "async" + split_file_list: False + thread_num: 1 + num_workers: 0 + use_fleet: True + +# hyper parameters of user-defined network +hyper_parameters: + # optimizer config + optimizer: + class: SGD + learning_rate: 0.1 + strategy: async + # user-defined pairs + sparse_inputs_slots: 27 + + dense_input_dim: 13 + bot_layer_sizes: [512, 256, 64, 16] + sparse_feature_number: 1000001 + sparse_feature_dim: 16 + top_layer_sizes: [512, 256, 2] + num_field: 26 diff --git a/recommendation/ctr/dlrm/paddlepaddle/net.py b/recommendation/ctr/dlrm/paddlepaddle/net.py new file mode 100644 index 0000000000000000000000000000000000000000..e2ca974ff9fbd4030bd9f3073190578883ba6e77 --- /dev/null +++ b/recommendation/ctr/dlrm/paddlepaddle/net.py @@ -0,0 +1,171 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import paddle +import paddle.nn as nn +import math +import numpy as np + +MIN_FLOAT = np.finfo(np.float32).min / 100.0 + + +class DLRMLayer(nn.Layer): + """Dot interaction layer. + + See theory in the DLRM paper: https://arxiv.org/pdf/1906.00091.pdf, + section 2.1.3. Sparse activations and dense activations are combined. + Dot interaction is applied to a batch of input Tensors [e1,...,e_k] of the + same dimension and the output is a batch of Tensors with all distinct pairwise + dot products of the form dot(e_i, e_j) for i <= j if self self_interaction is + True, otherwise dot(e_i, e_j) i < j. + """ + + def __init__(self, + dense_feature_dim, + bot_layer_sizes, + sparse_feature_number, + sparse_feature_dim, + top_layer_sizes, + num_field, + sync_mode=None, + self_interaction=False): + super(DLRMLayer, self).__init__() + self.dense_feature_dim = dense_feature_dim + self.bot_layer_sizes = bot_layer_sizes + self.sparse_feature_number = sparse_feature_number + self.sparse_feature_dim = sparse_feature_dim + self.top_layer_sizes = top_layer_sizes + self.num_field = num_field + self.self_interaction = self_interaction + + self.bot_mlp = MLPLayer( + input_shape=dense_feature_dim, + units_list=bot_layer_sizes, + activation="relu") + # `num_features * (num_features + 1) / 2` if self_interaction is True and + # `num_features * (num_features - 1) / 2` if self_interaction is False. + self.concat_size = int((num_field + 1) * (num_field + 2) / 2) if self.self_interaction \ + else int(num_field * (num_field + 1) / 2) + self.top_mlp = MLPLayer( + input_shape=self.concat_size + sparse_feature_dim, + units_list=top_layer_sizes) + + use_sparse = True + # if paddle.is_compiled_with_custom_device('npu'): + # use_sparse = False + + self.embedding = paddle.nn.Embedding( + num_embeddings=self.sparse_feature_number, + embedding_dim=self.sparse_feature_dim, + sparse=use_sparse, + weight_attr=paddle.ParamAttr( + name="SparseFeatFactors", + initializer=paddle.nn.initializer.TruncatedNormal())) + + def forward(self, sparse_inputs, dense_inputs): + """Performs the interaction operation on the tensors in the list. + + Args: + sparse_inputs: sparse categorical features, (batch_size, sparse_num_field) + dense_inputs: dense features, (batch_size, dense_feature_dim) + + Returns: predictions + """ + # (batch_size, sparse_feature_dim) + x = self.bot_mlp(dense_inputs) + batch_size, d = x.shape + + sparse_embs = [] + for s_input in sparse_inputs: + emb = self.embedding(s_input) + emb = paddle.reshape( + emb, shape=[batch_size, self.sparse_feature_dim]) + sparse_embs.append(emb) + + # concat dense embedding and sparse embeddings, (batch_size, (sparse_num_field + 1), embedding_size) + T = paddle.reshape( + paddle.concat( + x=sparse_embs + [x], axis=1), + (batch_size, self.num_field + 1, d)) + + # interact features, select upper-triangular portion + Z = paddle.bmm(T, paddle.transpose(T, perm=[0, 2, 1])) + + Zflat = paddle.triu(Z, 1) + paddle.tril( + x=paddle.ones_like(Z) * MIN_FLOAT, + diagonal=-1 if self.self_interaction else 0) + Zflat = paddle.reshape( + x=paddle.masked_select(Zflat, + paddle.greater_than( + Zflat, + paddle.ones_like(Zflat) * MIN_FLOAT)), + shape=(batch_size, self.concat_size)) + + R = paddle.concat([x] + [Zflat], axis=1) + + y = self.top_mlp(R) + return y + + +class MLPLayer(nn.Layer): + def __init__(self, input_shape, units_list=None, activation=None, + **kwargs): + super(MLPLayer, self).__init__(**kwargs) + + if units_list is None: + units_list = [128, 128, 64] + units_list = [input_shape] + units_list + + self.units_list = units_list + self.mlp = [] + self.activation = activation + + for i, unit in enumerate(units_list[:-1]): + if i != len(units_list) - 1: + dense = paddle.nn.Linear( + in_features=unit, + out_features=units_list[i + 1], + weight_attr=paddle.ParamAttr( + initializer=paddle.nn.initializer.TruncatedNormal( + std=1.0 / math.sqrt(unit)))) + self.mlp.append(dense) + self.add_sublayer('dense_%d' % i, dense) + + relu = paddle.nn.ReLU() + self.mlp.append(relu) + self.add_sublayer('relu_%d' % i, relu) + + norm = paddle.nn.BatchNorm1D(units_list[i + 1]) + self.mlp.append(norm) + self.add_sublayer('norm_%d' % i, norm) + else: + dense = paddle.nn.Linear( + in_features=unit, + out_features=units_list[i + 1], + weight_attr=paddle.ParamAttr( + initializer=paddle.nn.initializer.TruncatedNormal( + std=1.0 / math.sqrt(unit)))) + self.mlp.append(dense) + self.add_sublayer('dense_%d' % i, dense) + + if self.activation is not None: + relu = paddle.nn.ReLU() + self.mlp.append(relu) + self.add_sublayer('relu_%d' % i, relu) + + def forward(self, inputs): + outputs = inputs + for n_layer in self.mlp: + outputs = n_layer(outputs) + return outputs diff --git a/recommendation/ctr/dlrm/paddlepaddle/trainer.py b/recommendation/ctr/dlrm/paddlepaddle/trainer.py new file mode 100644 index 0000000000000000000000000000000000000000..5ea958e4e1919ab40029fdc27e77616ea048c781 --- /dev/null +++ b/recommendation/ctr/dlrm/paddlepaddle/trainer.py @@ -0,0 +1,231 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import paddle +import os +import paddle.nn as nn +import time +import logging +import sys +import importlib + +__dir__ = os.path.dirname(os.path.abspath(__file__)) +#sys.path.append(__dir__) +sys.path.append(os.path.abspath(os.path.join(__dir__, '..'))) + +from utils.utils_single import load_yaml, load_dy_model_class, get_abs_model, create_data_loader +from utils.save_load import load_model, save_model +from paddle.io import DistributedBatchSampler, DataLoader +import argparse + +logging.basicConfig( + format='%(asctime)s - %(levelname)s - %(message)s', level=logging.INFO) +logger = logging.getLogger(__name__) +logger.setLevel(logging.INFO) + + +def parse_args(): + parser = argparse.ArgumentParser(description='paddle-rec run') + parser.add_argument("-m", "--config_yaml", type=str) + parser.add_argument("-o", "--opt", nargs='*', type=str) + args = parser.parse_args() + args.abs_dir = os.path.dirname(os.path.abspath(args.config_yaml)) + args.config_yaml = get_abs_model(args.config_yaml) + return args + + +def main(args): + # load config + config = load_yaml(args.config_yaml) + dy_model_class = load_dy_model_class(args.abs_dir) + config["config_abs_dir"] = args.abs_dir + # modify config from command + if args.opt: + for parameter in args.opt: + parameter = parameter.strip() + key, value = parameter.split("=") + if type(config.get(key)) is int: + value = int(value) + if type(config.get(key)) is float: + value = float(value) + if type(config.get(key)) is bool: + value = (True if value.lower() == "true" else False) + config[key] = value + + # tools.vars + use_gpu = config.get("runner.use_gpu", True) + use_auc = config.get("runner.use_auc", False) + use_npu = config.get("runner.use_npu", False) + use_xpu = config.get("runner.use_xpu", False) + use_visual = config.get("runner.use_visual", False) + train_data_dir = config.get("runner.train_data_dir", None) + epochs = config.get("runner.epochs", None) + print_interval = config.get("runner.print_interval", None) + train_batch_size = config.get("runner.train_batch_size", None) + model_save_path = config.get("runner.model_save_path", "model_output") + model_init_path = config.get("runner.model_init_path", None) + use_fleet = config.get("runner.use_fleet", False) + seed = config.get("runner.seed", 12345) + paddle.seed(seed) + + logger.info("**************common.configs**********") + logger.info( + "use_gpu: {}, use_xpu: {}, use_npu: {}, use_visual: {}, train_batch_size: {}, train_data_dir: {}, epochs: {}, print_interval: {}, model_save_path: {}, use_fleet: {}". + format(use_gpu, use_xpu, use_npu, use_visual, train_batch_size, + train_data_dir, epochs, print_interval, model_save_path, use_fleet)) + logger.info("**************common.configs**********") + + if use_xpu: + xpu_device = 'xpu:{0}'.format(os.getenv('FLAGS_selected_xpus', 0)) + place = paddle.set_device(xpu_device) + elif use_npu: + npu_device = 'npu:{0}'.format(os.getenv('FLAGS_selected_npus', 0)) + place = paddle.set_device(npu_device) + else: + place = paddle.set_device('gpu' if use_gpu else 'cpu') + + dy_model = dy_model_class.create_model(config) + + # Create a log_visual object and store the data in the path + if use_visual: + from visualdl import LogWriter + log_visual = LogWriter(args.abs_dir + "/visualDL_log/train") + + if model_init_path is not None: + load_model(model_init_path, dy_model) + + # to do : add optimizer function + optimizer = dy_model_class.create_optimizer(dy_model, config) + + # use fleet run collective + if use_fleet: + from paddle.distributed import fleet + strategy = fleet.DistributedStrategy() + fleet.init(is_collective=True, strategy=strategy) + optimizer = fleet.distributed_optimizer(optimizer) + dy_model = fleet.distributed_model(dy_model) + + logger.info("read data") + train_dataloader = create_data_loader(config=config, place=place) + + last_epoch_id = config.get("last_epoch", -1) + step_num = 0 + + for epoch_id in range(last_epoch_id + 1, epochs): + # set train mode + dy_model.train() + metric_list, metric_list_name = dy_model_class.create_metrics() + #auc_metric = paddle.metric.Auc("ROC") + epoch_begin = time.time() + interval_begin = time.time() + train_reader_cost = 0.0 + train_run_cost = 0.0 + total_samples = 0 + reader_start = time.time() + + #we will drop the last incomplete batch when dataset size is not divisible by the batch size + assert any(train_dataloader( + )), "train_dataloader is null, please ensure batch size < dataset size!" + + for batch_id, batch in enumerate(train_dataloader()): + train_reader_cost += time.time() - reader_start + optimizer.clear_grad() + train_start = time.time() + batch_size = len(batch[0]) + + loss, metric_list, tensor_print_dict = dy_model_class.train_forward( + dy_model, metric_list, batch, config) + + loss.backward() + optimizer.step() + train_run_cost += time.time() - train_start + + if not use_fleet: + total_samples += batch_size + else: + total_samples += batch_size * paddle.distributed.get_world_size() + + if batch_id % print_interval == 0: + metric_str = "" + for metric_id in range(len(metric_list_name)): + metric_str += ( + metric_list_name[metric_id] + + ":{:.6f}, ".format(metric_list[metric_id].accumulate()) + ) + if use_visual: + log_visual.add_scalar( + tag="train/" + metric_list_name[metric_id], + step=step_num, + value=metric_list[metric_id].accumulate()) + tensor_print_str = "" + if tensor_print_dict is not None: + for var_name, var in tensor_print_dict.items(): + tensor_print_str += ( + "{}:".format(var_name) + + str(var.numpy()).strip("[]") + ",") + if use_visual: + log_visual.add_scalar( + tag="train/" + var_name, + step=step_num, + value=var.numpy()) + logger.info( + "epoch: {}, batch_id: {}, ".format( + epoch_id, batch_id) + metric_str + tensor_print_str + + " avg_reader_cost: {:.5f} sec, avg_batch_cost: {:.5f} sec, avg_samples: {:.5f}, ips: {:.5f} ins/s". + format(train_reader_cost / print_interval, ( + train_reader_cost + train_run_cost) / print_interval, + total_samples / print_interval, total_samples / ( + train_reader_cost + train_run_cost + 0.0001))) + train_reader_cost = 0.0 + train_run_cost = 0.0 + total_samples = 0 + reader_start = time.time() + step_num = step_num + 1 + + metric_str = "" + for metric_id in range(len(metric_list_name)): + metric_str += ( + metric_list_name[metric_id] + + ": {:.6f},".format(metric_list[metric_id].accumulate())) + if use_auc: + metric_list[metric_id].reset() + + tensor_print_str = "" + if tensor_print_dict is not None: + for var_name, var in tensor_print_dict.items(): + tensor_print_str += ( + "{}:".format(var_name) + str(var.numpy()).strip("[]") + "," + ) + + logger.info("epoch: {} done, ".format(epoch_id) + metric_str + + tensor_print_str + " epoch time: {:.2f} s".format( + time.time() - epoch_begin)) + + if use_fleet: + trainer_id = paddle.distributed.get_rank() + if trainer_id == 0: + save_model( + dy_model, + optimizer, + model_save_path, + epoch_id, + prefix='rec') + else: + save_model( + dy_model, optimizer, model_save_path, epoch_id, prefix='rec') + + +if __name__ == '__main__': + args = parse_args() + main(args)