diff --git a/recommendation/ctr/dlrm/pytorch/dlrm/dist_model.py b/recommendation/ctr/dlrm/pytorch/dlrm/dist_model.py index 467fb9ba21990af68f642343f78efff6b33d2eaa..de79a510a122621bdb25ee69c2d4559688fe911b 100644 --- a/recommendation/ctr/dlrm/pytorch/dlrm/dist_model.py +++ b/recommendation/ctr/dlrm/pytorch/dlrm/dist_model.py @@ -1,3 +1,19 @@ +# Copyright (c) 2022, Shanghai Iluvatar CoreX Semiconductor Co., Ltd. +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + + """Distributed version of DLRM model In order to code the hybrid decomposition, the model code needs to be restructured. I don't know a clean enough @@ -427,3 +443,10 @@ class DistDlrm(): def to(self, *args, **kwargs): self.bottom_model.to(*args, **kwargs) self.top_model.to(*args, **kwargs) + + def state_dict(self): + dlrm_state_dic = {} + dlrm_state_dic.update(self.bottom_model.state_dict()) + dlrm_state_dic.update(self.top_model.state_dict()) + + return dlrm_state_dic diff --git a/recommendation/ctr/dlrm/pytorch/scripts/train.py b/recommendation/ctr/dlrm/pytorch/scripts/train.py index 24eed49c7ff12f28d9acfb272929aeaf8d30481f..3a57660fb3aa05f16466d441885f31a424eed1ed 100644 --- a/recommendation/ctr/dlrm/pytorch/scripts/train.py +++ b/recommendation/ctr/dlrm/pytorch/scripts/train.py @@ -1,3 +1,18 @@ +# Copyright (c) 2022, Shanghai Iluvatar CoreX Semiconductor Co., Ltd. +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + """Reference training script Only Criteo data is supported at the moment, one hot embedding. """ @@ -53,7 +68,7 @@ flags.DEFINE_enum("dataset_type", "memmap", ["bin", "memmap", "dist"], "Which da flags.DEFINE_boolean("use_embedding_ext", True, "Use embedding cuda extension. If False, use Pytorch embedding") # Saving and logging flags -flags.DEFINE_string("output_dir", "/tmp", "path where to save") +flags.DEFINE_string("output_dir", ".", "path where to save") flags.DEFINE_integer("test_freq", None, "#steps test. If None, 20 tests per epoch per MLperf rule.") flags.DEFINE_float("test_after", 0, "Don't test the model unless this many epochs has been completed") flags.DEFINE_integer("print_freq", None, "#steps per pring")