diff --git a/nlp/dialogue_generation/cpm/pytorch/base/run_pretraining.py b/nlp/dialogue_generation/cpm/pytorch/base/run_pretraining.py
index 9770a23145a494ea1aac3e8d732f0f6024ddad0c..c05eb1a4d5ea185ca36c7bb1619a8b6ff4087f73 100755
--- a/nlp/dialogue_generation/cpm/pytorch/base/run_pretraining.py
+++ b/nlp/dialogue_generation/cpm/pytorch/base/run_pretraining.py
@@ -105,6 +105,9 @@ def main():
     training_event.on_train_end()
     raw_train_end_time = logger.previous_log_time
     training_state.raw_train_time = (raw_train_end_time - raw_train_start_time) / 1e+3
+
+    trainer.save_checkpoint()
+
     return config, training_state
 
 if __name__ == "__main__":
diff --git a/nlp/dialogue_generation/cpm/pytorch/base/train/trainer.py b/nlp/dialogue_generation/cpm/pytorch/base/train/trainer.py
index cf201efe904d9cf5ec426c962f60756e6c0440df..20139a72ce590a6419daab5f6b11c74fadb9e8cb 100755
--- a/nlp/dialogue_generation/cpm/pytorch/base/train/trainer.py
+++ b/nlp/dialogue_generation/cpm/pytorch/base/train/trainer.py
@@ -150,3 +150,8 @@ class Trainer():
         ])
 
         return do_eval or state.global_steps >= self.config.max_steps
+
+    def save_checkpoint(self):
+        if self.config.n_gpu == 1 or (self.config.n_gpu > 1 and self.config.device == 0):
+            print("save checkpoint...")
+            torch.save(getattr(self.model, "module", self.model).state_dict(), "cpm_model_states_medium_end2end.pt")
\ No newline at end of file
diff --git a/recommendation/ctr/dlrm/pytorch/dlrm/dist_model.py b/recommendation/ctr/dlrm/pytorch/dlrm/dist_model.py
index 467fb9ba21990af68f642343f78efff6b33d2eaa..083257d8d97ff65ceddd308c691ecb3e718b63cd 100644
--- a/recommendation/ctr/dlrm/pytorch/dlrm/dist_model.py
+++ b/recommendation/ctr/dlrm/pytorch/dlrm/dist_model.py
@@ -1,3 +1,20 @@
+
+# Copyright (c) 2022, Shanghai Iluvatar CoreX Semiconductor Co., Ltd.
+# All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you may
+# not use this file except in compliance with the License. You may obtain
+# a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+# License for the specific language governing permissions and limitations
+# under the License.
+
+
 """Distributed version of DLRM model
 In order to code the hybrid decomposition, the model code needs to be
 restructured. I don't know a clean enough
@@ -427,3 +444,10 @@ class DistDlrm():
     def to(self, *args, **kwargs):
         self.bottom_model.to(*args, **kwargs)
         self.top_model.to(*args, **kwargs)
+
+    def state_dict(self):
+        dlrm_state_dict = {}
+        dlrm_state_dict.update(self.bottom_model.state_dict())
+        dlrm_state_dict.update(self.top_model.state_dict())
+
+        return dlrm_state_dict
diff --git a/recommendation/ctr/dlrm/pytorch/scripts/train.py b/recommendation/ctr/dlrm/pytorch/scripts/train.py
index 24eed49c7ff12f28d9acfb272929aeaf8d30481f..3a57660fb3aa05f16466d441885f31a424eed1ed 100644
--- a/recommendation/ctr/dlrm/pytorch/scripts/train.py
+++ b/recommendation/ctr/dlrm/pytorch/scripts/train.py
@@ -1,3 +1,18 @@
+# Copyright (c) 2022, Shanghai Iluvatar CoreX Semiconductor Co., Ltd.
+# All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you may
+# not use this file except in compliance with the License. You may obtain
+# a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+# License for the specific language governing permissions and limitations
+# under the License.
+
 """Reference training script
 Only Criteo data is supported at the moment, one hot embedding.
 """
@@ -53,7 +68,7 @@ flags.DEFINE_enum("dataset_type", "memmap", ["bin", "memmap", "dist"], "Which da
 flags.DEFINE_boolean("use_embedding_ext", True, "Use embedding cuda extension. If False, use Pytorch embedding")
 
 # Saving and logging flags
-flags.DEFINE_string("output_dir", "/tmp", "path where to save")
+flags.DEFINE_string("output_dir", ".", "path where to save")
 flags.DEFINE_integer("test_freq", None, "#steps test. If None, 20 tests per epoch per MLperf rule.")
 flags.DEFINE_float("test_after", 0, "Don't test the model unless this many epochs has been completed")
 flags.DEFINE_integer("print_freq", None, "#steps per pring")
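
A note on the new Trainer.save_checkpoint() above: the checkpoint is written only on the primary device, and the state dict is taken from the unwrapped model, so the saved keys carry no DistributedDataParallel "module." prefix. The following is a minimal, hypothetical sketch of the reload side; load_cpm_checkpoint and its default path are illustrative names, not part of this patch.

import torch

def load_cpm_checkpoint(model, path="cpm_model_states_medium_end2end.pt"):
    # Load the state dict written by Trainer.save_checkpoint() onto CPU first,
    # so the caller can move the model to any device afterwards.
    state_dict = torch.load(path, map_location="cpu")
    # Defensive: strip a leading "module." prefix in case a checkpoint was
    # produced by saving a DDP-wrapped model directly.
    state_dict = {k[len("module."):] if k.startswith("module.") else k: v
                  for k, v in state_dict.items()}
    model.load_state_dict(state_dict)
    return model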
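Similarly, the new DistDlrm.state_dict() flattens the bottom and top sub-modules into a single dict, so the wrapper can be checkpointed even though DistDlrm is not itself an nn.Module. A rough sketch of the save side, assuming a DistDlrm instance and the FLAGS.output_dir default changed above; the file name "dlrm_model.pt" is made up for illustration.

import os
import torch

def save_dlrm_checkpoint(model, output_dir="."):
    # model is a DistDlrm instance; state_dict() returns a plain dict merging
    # bottom_model and top_model parameters via dict.update().
    os.makedirs(output_dir, exist_ok=True)
    torch.save(model.state_dict(), os.path.join(output_dir, "dlrm_model.pt"))

One caveat of the merge: because the two sub-dicts are combined with plain dict.update(), any parameter name that appears in both bottom_model and top_model would be silently overwritten by the top model's entry; prefixing keys (e.g. "bottom." / "top.") would make the merge collision-proof.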