From f04ec370f9115e98a091f2a5126c6638c0294698 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=8E=8B=E5=87=AF=E5=AE=87?= Date: Wed, 30 Apr 2025 15:00:10 +0800 Subject: [PATCH] [built-in][PyTorch][OpenRLHF] upload OpenRLHF original code --- .../.github/workflows/python-package.yml | 58 ++ .../rl/OpenRLHF_v0.5.7_for_PyTorch/.gitignore | 151 ++++ .../.pre-commit-config.yaml | 38 + .../CONTRIBUTING.md | 7 + .../rl/OpenRLHF_v0.5.7_for_PyTorch/LICENSE | 201 +++++ .../rl/OpenRLHF_v0.5.7_for_PyTorch/README.md | 477 +++++++++++ .../OpenRLHF_v0.5.7_for_PyTorch/README_ja.md | 458 +++++++++++ .../OpenRLHF_v0.5.7_for_PyTorch/README_zh.md | 482 +++++++++++ .../dockerfile/Dockerfile | 23 + .../dockerfile/docker-entrypoint.sh | 25 + .../docs/ppo_examples.md | 69 ++ .../examples/scripts/docker_run.sh | 8 + .../examples/scripts/nvidia_docker_install.sh | 28 + .../examples/scripts/serve_remote_rm.sh | 10 + .../scripts/train_conditional_llama.sh | 50 ++ .../scripts/train_continue_pretrain_llama.sh | 29 + .../examples/scripts/train_dpo_llama.sh | 36 + .../examples/scripts/train_dpo_llama_34b.sh | 33 + .../examples/scripts/train_dpo_ring_llama.sh | 41 + .../scripts/train_iterative_dpo_llama.sh | 99 +++ .../scripts/train_knowledge_distillation.sh | 32 + .../examples/scripts/train_kto_llama.sh | 31 + .../examples/scripts/train_llama_slurm.sh | 42 + .../examples/scripts/train_ppo_llama.sh | 40 + .../examples/scripts/train_ppo_llama_ray.sh | 49 ++ .../scripts/train_ppo_llama_ray_70b.sh | 43 + .../scripts/train_ppo_llama_ray_slurm.sh | 98 +++ .../scripts/train_ppo_llama_with_remote_rm.sh | 50 ++ .../examples/scripts/train_prm_mistral.sh | 34 + .../scripts/train_reinforce_llama_ray.sh | 43 + .../scripts/train_rejection_sampling_llama.sh | 99 +++ .../examples/scripts/train_rm_llama.sh | 31 + .../examples/scripts/train_sft_llama.sh | 30 + .../scripts/train_sft_mixtral_lora.sh | 30 + .../openrlhf/__init__.py | 0 .../openrlhf/cli/__init__.py | 0 .../openrlhf/cli/batch_inference.py | 365 +++++++++ .../openrlhf/cli/interactive_chat.py | 128 +++ .../openrlhf/cli/lora_combiner.py | 45 ++ .../openrlhf/cli/serve_rm.py | 119 +++ .../openrlhf/cli/train_dpo.py | 278 +++++++ .../openrlhf/cli/train_kd.py | 232 ++++++ .../openrlhf/cli/train_kto.py | 226 ++++++ .../openrlhf/cli/train_ppo.py | 443 ++++++++++ .../openrlhf/cli/train_ppo_ray.py | 400 +++++++++ .../openrlhf/cli/train_prm.py | 204 +++++ .../openrlhf/cli/train_rm.py | 261 ++++++ .../openrlhf/cli/train_sft.py | 258 ++++++ .../openrlhf/datasets/__init__.py | 7 + .../datasets/process_reward_dataset.py | 136 ++++ .../openrlhf/datasets/prompts_dataset.py | 57 ++ .../openrlhf/datasets/reward_dataset.py | 235 ++++++ .../openrlhf/datasets/sft_dataset.py | 231 ++++++ .../datasets/unpaired_preference_dataset.py | 143 ++++ .../openrlhf/datasets/utils.py | 17 + .../openrlhf/models/__init__.py | 29 + .../openrlhf/models/actor.py | 242 ++++++ .../openrlhf/models/loss.py | 326 ++++++++ .../openrlhf/models/model.py | 304 +++++++ .../openrlhf/models/ring_attn_utils.py | 74 ++ .../openrlhf/models/utils.py | 120 +++ .../openrlhf/trainer/__init__.py | 17 + .../openrlhf/trainer/dpo_trainer.py | 478 +++++++++++ .../openrlhf/trainer/kd_trainer.py | 249 ++++++ .../openrlhf/trainer/kto_trainer.py | 342 ++++++++ .../openrlhf/trainer/ppo_trainer.py | 523 ++++++++++++ .../openrlhf/trainer/ppo_utils/__init__.py | 12 + .../trainer/ppo_utils/experience_maker.py | 765 ++++++++++++++++++ .../trainer/ppo_utils/kl_controller.py | 29 + .../trainer/ppo_utils/replay_buffer.py | 237 ++++++ 
.../openrlhf/trainer/prm_trainer.py | 249 ++++++ .../openrlhf/trainer/ray/__init__.py | 14 + .../openrlhf/trainer/ray/launcher.py | 316 ++++++++ .../openrlhf/trainer/ray/ppo_actor.py | 448 ++++++++++ .../openrlhf/trainer/ray/ppo_critic.py | 197 +++++ .../openrlhf/trainer/ray/utils.py | 22 + .../openrlhf/trainer/ray/vllm_engine.py | 149 ++++ .../openrlhf/trainer/ray/vllm_worker_wrap.py | 58 ++ .../openrlhf/trainer/rm_trainer.py | 370 +++++++++ .../openrlhf/trainer/sft_trainer.py | 318 ++++++++ .../openrlhf/utils/__init__.py | 10 + .../openrlhf/utils/deepspeed/__init__.py | 5 + .../openrlhf/utils/deepspeed/deepspeed.py | 467 +++++++++++ .../utils/deepspeed/deepspeed_utils.py | 107 +++ .../openrlhf/utils/distributed_sampler.py | 151 ++++ .../openrlhf/utils/distributed_util.py | 72 ++ .../openrlhf/utils/logging_utils.py | 56 ++ .../openrlhf/utils/processor.py | 103 +++ .../openrlhf/utils/remote_rm_utils.py | 52 ++ .../openrlhf/utils/utils.py | 125 +++ .../pyproject.toml | 56 ++ .../requirements.txt | 21 + .../rl/OpenRLHF_v0.5.7_for_PyTorch/setup.py | 89 ++ .../OpenRLHF_v0.5.7_for_PyTorch/version.txt | 1 + 94 files changed, 13963 insertions(+) create mode 100644 PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/.github/workflows/python-package.yml create mode 100644 PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/.gitignore create mode 100644 PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/.pre-commit-config.yaml create mode 100644 PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/CONTRIBUTING.md create mode 100644 PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/LICENSE create mode 100644 PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/README.md create mode 100644 PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/README_ja.md create mode 100644 PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/README_zh.md create mode 100644 PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/dockerfile/Dockerfile create mode 100644 PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/dockerfile/docker-entrypoint.sh create mode 100644 PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/docs/ppo_examples.md create mode 100644 PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/examples/scripts/docker_run.sh create mode 100644 PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/examples/scripts/nvidia_docker_install.sh create mode 100644 PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/examples/scripts/serve_remote_rm.sh create mode 100644 PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/examples/scripts/train_conditional_llama.sh create mode 100644 PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/examples/scripts/train_continue_pretrain_llama.sh create mode 100644 PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/examples/scripts/train_dpo_llama.sh create mode 100644 PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/examples/scripts/train_dpo_llama_34b.sh create mode 100644 PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/examples/scripts/train_dpo_ring_llama.sh create mode 100644 PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/examples/scripts/train_iterative_dpo_llama.sh create mode 100644 PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/examples/scripts/train_knowledge_distillation.sh create mode 100644 PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/examples/scripts/train_kto_llama.sh create mode 100644 PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/examples/scripts/train_llama_slurm.sh create mode 100644 PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/examples/scripts/train_ppo_llama.sh create mode 100644 
PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/examples/scripts/train_ppo_llama_ray.sh create mode 100644 PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/examples/scripts/train_ppo_llama_ray_70b.sh create mode 100644 PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/examples/scripts/train_ppo_llama_ray_slurm.sh create mode 100644 PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/examples/scripts/train_ppo_llama_with_remote_rm.sh create mode 100644 PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/examples/scripts/train_prm_mistral.sh create mode 100644 PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/examples/scripts/train_reinforce_llama_ray.sh create mode 100644 PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/examples/scripts/train_rejection_sampling_llama.sh create mode 100644 PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/examples/scripts/train_rm_llama.sh create mode 100644 PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/examples/scripts/train_sft_llama.sh create mode 100644 PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/examples/scripts/train_sft_mixtral_lora.sh create mode 100644 PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/openrlhf/__init__.py create mode 100644 PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/openrlhf/cli/__init__.py create mode 100644 PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/openrlhf/cli/batch_inference.py create mode 100644 PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/openrlhf/cli/interactive_chat.py create mode 100644 PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/openrlhf/cli/lora_combiner.py create mode 100644 PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/openrlhf/cli/serve_rm.py create mode 100644 PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/openrlhf/cli/train_dpo.py create mode 100644 PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/openrlhf/cli/train_kd.py create mode 100644 PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/openrlhf/cli/train_kto.py create mode 100644 PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/openrlhf/cli/train_ppo.py create mode 100644 PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/openrlhf/cli/train_ppo_ray.py create mode 100644 PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/openrlhf/cli/train_prm.py create mode 100644 PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/openrlhf/cli/train_rm.py create mode 100644 PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/openrlhf/cli/train_sft.py create mode 100644 PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/openrlhf/datasets/__init__.py create mode 100644 PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/openrlhf/datasets/process_reward_dataset.py create mode 100644 PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/openrlhf/datasets/prompts_dataset.py create mode 100644 PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/openrlhf/datasets/reward_dataset.py create mode 100644 PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/openrlhf/datasets/sft_dataset.py create mode 100644 PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/openrlhf/datasets/unpaired_preference_dataset.py create mode 100644 PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/openrlhf/datasets/utils.py create mode 100644 PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/openrlhf/models/__init__.py create mode 100644 PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/openrlhf/models/actor.py create mode 100644 PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/openrlhf/models/loss.py create mode 100644 PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/openrlhf/models/model.py create mode 100644 
PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/openrlhf/models/ring_attn_utils.py create mode 100644 PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/openrlhf/models/utils.py create mode 100644 PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/openrlhf/trainer/__init__.py create mode 100644 PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/openrlhf/trainer/dpo_trainer.py create mode 100644 PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/openrlhf/trainer/kd_trainer.py create mode 100644 PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/openrlhf/trainer/kto_trainer.py create mode 100644 PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/openrlhf/trainer/ppo_trainer.py create mode 100644 PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/openrlhf/trainer/ppo_utils/__init__.py create mode 100644 PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/openrlhf/trainer/ppo_utils/experience_maker.py create mode 100644 PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/openrlhf/trainer/ppo_utils/kl_controller.py create mode 100644 PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/openrlhf/trainer/ppo_utils/replay_buffer.py create mode 100644 PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/openrlhf/trainer/prm_trainer.py create mode 100644 PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/openrlhf/trainer/ray/__init__.py create mode 100644 PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/openrlhf/trainer/ray/launcher.py create mode 100644 PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/openrlhf/trainer/ray/ppo_actor.py create mode 100644 PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/openrlhf/trainer/ray/ppo_critic.py create mode 100644 PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/openrlhf/trainer/ray/utils.py create mode 100644 PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/openrlhf/trainer/ray/vllm_engine.py create mode 100644 PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/openrlhf/trainer/ray/vllm_worker_wrap.py create mode 100644 PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/openrlhf/trainer/rm_trainer.py create mode 100644 PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/openrlhf/trainer/sft_trainer.py create mode 100644 PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/openrlhf/utils/__init__.py create mode 100644 PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/openrlhf/utils/deepspeed/__init__.py create mode 100644 PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/openrlhf/utils/deepspeed/deepspeed.py create mode 100644 PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/openrlhf/utils/deepspeed/deepspeed_utils.py create mode 100644 PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/openrlhf/utils/distributed_sampler.py create mode 100644 PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/openrlhf/utils/distributed_util.py create mode 100644 PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/openrlhf/utils/logging_utils.py create mode 100644 PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/openrlhf/utils/processor.py create mode 100644 PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/openrlhf/utils/remote_rm_utils.py create mode 100644 PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/openrlhf/utils/utils.py create mode 100644 PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/pyproject.toml create mode 100644 PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/requirements.txt create mode 100644 PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/setup.py create mode 100644 PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/version.txt diff --git a/PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/.github/workflows/python-package.yml 
b/PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/.github/workflows/python-package.yml new file mode 100644 index 0000000000..4fac7cb7f6 --- /dev/null +++ b/PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/.github/workflows/python-package.yml @@ -0,0 +1,58 @@ +name: Python package + +on: + release: + types: [published] + workflow_dispatch: + +jobs: + build-and-publish: + # do not run in forks + if: ${{ github.repository_owner == 'OpenRLHF' && (github.event_name == 'release' || github.event_name == 'workflow_dispatch') }} + name: build wheel and upload + runs-on: ubuntu-22.04 + + strategy: + matrix: + python-version: [3.10.14, 3.11.0] + cuda-version: [12.1] + + steps: + - uses: actions/checkout@v3 + + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v4 + with: + python-version: ${{ matrix.python-version }} + + - name: Install CUDA ${{ matrix.cuda-version }} + run: | + wget https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2204/x86_64/cuda-ubuntu2204.pin + sudo mv cuda-ubuntu2204.pin /etc/apt/preferences.d/cuda-repository-pin-600 + sudo apt-key adv --fetch-keys https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2204/x86_64/3bf863cc.pub + sudo add-apt-repository "deb https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2204/x86_64/ /" + sudo apt-get update + sudo apt-get -y install cuda-${{ matrix.cuda-version }} + + - name: Set up CUDA environment variables + run: | + echo "/usr/local/cuda-${{ matrix.cuda-version }}/lib64" | sudo tee -a /etc/ld.so.conf.d/cuda.conf + echo "export PATH=/usr/local/cuda-${{ matrix.cuda-version }}/bin:\$PATH" | sudo tee -a /etc/environment + sudo ldconfig + shell: bash + + - name: Install dependencies + run: | + python -m pip install --upgrade pip + pip install setuptools wheel twine packaging + + - name: Build package + run: | + python setup.py bdist_wheel --dist-dir=dist + + - name: Publish package + env: + TWINE_USERNAME: __token__ + TWINE_PASSWORD: ${{ secrets.PYPI_API_TOKEN }} + run: | + python -m twine upload dist/* \ No newline at end of file diff --git a/PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/.gitignore b/PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/.gitignore new file mode 100644 index 0000000000..e9b9f0e6c0 --- /dev/null +++ b/PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/.gitignore @@ -0,0 +1,151 @@ +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +pip-wheel-metadata/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. 
+*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +*.py,cover +.hypothesis/ +.pytest_cache/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 +db.sqlite3-journal + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ +docs/.build/ + +# PyBuilder +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# IPython +profile_default/ +ipython_config.py + +# pyenv +.python-version + +# pipenv +# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. +# However, in case of collaboration, if having platform-specific dependencies or dependencies +# having no cross-platform support, pipenv may install dependencies that don't work, or not +# install all needed dependencies. +#Pipfile.lock + +# PEP 582; used by e.g. github.com/David-OConnor/pyflow +__pypackages__/ + +# Celery stuff +celerybeat-schedule +celerybeat.pid + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ + +# IDE +.idea/ +.vscode/ + +# macos +*.DS_Store +#data/ + +docs/.build + +# pytorch checkpoint +*.pt + +core +*/ckpt/* +.vscode +.nfs* +*jianh* +*test_scripts* +*/checkpoint/* \ No newline at end of file diff --git a/PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/.pre-commit-config.yaml b/PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/.pre-commit-config.yaml new file mode 100644 index 0000000000..fa5e59cff6 --- /dev/null +++ b/PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/.pre-commit-config.yaml @@ -0,0 +1,38 @@ +default_language_version: + python: python3 + +ci: + autofix_prs: true + autoupdate_commit_msg: '[pre-commit.ci] pre-commit suggestions' + autoupdate_schedule: quarterly + +repos: + - repo: https://github.com/pre-commit/pre-commit-hooks + rev: v4.5.0 + hooks: + - id: check-yaml + - id: check-case-conflict + - id: detect-private-key + - id: check-added-large-files + args: ['--maxkb=1000'] + - id: requirements-txt-fixer + + - repo: https://github.com/PyCQA/autoflake + rev: v2.0.2 + hooks: + - id: autoflake + args: [--remove-all-unused-imports, --in-place] + + - repo: https://github.com/PyCQA/isort + rev: 5.13.2 + hooks: + - id: isort + name: Format imports + exclude: docs/ + + - repo: https://github.com/psf/black + rev: 24.3.0 + hooks: + - id: black + name: Format code + additional_dependencies: ['click==8.0.2'] diff --git a/PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/CONTRIBUTING.md b/PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/CONTRIBUTING.md new file mode 100644 index 0000000000..e93c04b6db --- /dev/null +++ b/PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/CONTRIBUTING.md @@ -0,0 +1,7 @@ +# Contributing to OpenRLHF + +After cloning the repository, please install pre-commit hooks with: +``` +pip install pre-commit +pre-commit install +``` \ No newline at end of file diff --git a/PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/LICENSE b/PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/LICENSE new file mode 100644 index 0000000000..f49a4e16e6 --- /dev/null +++ b/PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/LICENSE @@ -0,0 +1,201 @@ + Apache License + Version 
2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. 
Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. 
Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. 
+ You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. \ No newline at end of file diff --git a/PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/README.md b/PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/README.md new file mode 100644 index 0000000000..2cb8191853 --- /dev/null +++ b/PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/README.md @@ -0,0 +1,477 @@ +
+OpenRLHF logo
+Badges: GitHub Contributors | Issues | GitHub pull requests | GitHub stars
+Open-source / Comprehensive / Lightweight / Easy-to-use
+
+[ English | 中文 | 日本語 ]
+
+OpenRLHF is a high-performance RLHF framework built on Ray, DeepSpeed and HF Transformers:
+
+- **Simple and easy to use**: OpenRLHF is one of the simplest high-performance RLHF libraries currently available, and it is seamlessly compatible with Huggingface models and datasets.
+- **High performance**: RLHF training spends 80% of the time on the sample generation stage. Thanks to the ability to use a large inference batch size with Ray, Packing Samples, and vLLM generation acceleration, the performance of OpenRLHF is 3~4x+ that of Optimized DeepSpeedChat with Hybrid Engine.
+- **Distributed RLHF**: OpenRLHF distributes the Actor, Reward, Reference, and Critic models onto separate GPUs using Ray, while placing the Adam optimizer on the CPU. This enables full-scale fine-tuning of 70B+ models on multiple A100 80G GPUs with vLLM, and of 7B models across multiple 24GB RTX 4090 GPUs.
+- **PPO Implementation Optimization**: We integrated implementation tricks for PPO to improve training stability, referencing [Zhihu](https://zhuanlan.zhihu.com/p/622134699) and the [Notion blog](https://hijkzzz.notion.site/rlhf-implementation-tricks?v=158d9a33ecc98132bf9e000c39227361).
+
+More details are in [Slides](https://docs.google.com/presentation/d/1JRhB1d7csofx0PIZBmfyBdMluxNd5JLPpUHrrvVhGnk/edit?usp=sharing) | [Technical Report](https://arxiv.org/abs/2405.11143) | [Documents](https://openrlhf.readthedocs.io/)
+
+## News
+- [2025/1] HKUST reproduced the [DeepSeek-R1-Zero and DeepSeek-R1 training on small models using OpenRLHF](https://github.com/hkust-nlp/simpleRL-reason)
+- [2024/12] We "proposed" 😊 the [REINFORCE++: A Simple and Efficient Approach for Aligning Large Language Models](https://www.researchgate.net/publication/387487679_REINFORCE_A_SIMPLE_AND_EFFICIENT_APPROACH_FOR_ALIGNING_LARGE_LANGUAGE_MODELS).
+- [2024/12] We analyzed PPO, REINFORCE++, GRPO and RLOO in the [Notion Blogpost](https://hijkzzz.notion.site/unraveling-rlhf-and-its-variants-engineering-insights#147d9a33ecc9806090f3d5c749d31f05).
+
+
+## Features
+
+- Distributed [PPO](./examples/scripts/train_ppo_llama_ray.sh) and [REINFORCE++/RLOO](./examples/scripts/train_reinforce_llama_ray.sh) implementations based on Ray.
+- Full RLHF fine-tuning support for models with [over 70 billion parameters](./examples/scripts/train_ppo_llama_ray_70b.sh).
+- Integration with vLLM for accelerated generation in RLHF tasks (`--vllm_num_engines`).
+- Support for multiple reward models (`--reward_pretrain model1,model2...`) and remote reward models (`--remote_rm_url`).
+- Implementation of [DPO (Direct Preference Optimization)/IPO/cDPO](./examples/scripts/train_dpo_llama.sh) and [Kahneman-Tversky Optimization (KTO)](./examples/scripts/train_kto_llama.sh).
+- Support for [Iterative DPO](./examples/scripts/train_iterative_dpo_llama.sh) ([GitHub: Online-RLHF](https://github.com/RLHFlow/Online-RLHF)).
+- Support for [Rejection Sampling](./examples/scripts/train_rejection_sampling_llama.sh).
+- Implementation of [Conditional SFT](./examples/scripts/train_conditional_llama.sh) ([arXiv:2308.12050](https://arxiv.org/abs/2308.12050)).
+- Support for [Knowledge Distillation](./examples/scripts/train_knowledge_distillation.sh) ([Microsoft: minillm](https://github.com/microsoft/LMOps/tree/main/minillm)).
+- Integration of [Process Reward Model (PRM)](./examples/scripts/train_prm_mistral.sh).
+- Packing of training samples for SFT, DPO, RM, PRM, and PPO (`--packing_samples`). 
+- Implementation of [RingAttention](./examples/scripts/train_dpo_ring_llama.sh) (`--ring_attn_size`, `--ring_head_stride`). +- Support for [Mixture of Experts (MoE)](./examples/test_scripts/train_sft_mixtral_lora.sh) (`--aux_loss_coef`). +- Integration of FlashAttention2 (`--flash_attn`). +- Support for QLoRA (`--load_in_4bit`) and [LoRA](./examples/scripts/train_sft_mixtral_lora.sh) (`--lora_rank`, `--target_modules`). +- Compatibility with HuggingFace's `tokenizer.apply_chat_template` for datasets (`--apply_chat_template` and `--input_key`). +- Logging support with Wandb (`--use_wandb`) and TensorBoard (`--use_tensorboard`). +- Checkpoint recovery functionality (`--load_checkpoint` and `--save_steps`). +- Provided multi-node training scripts, such as [DPO](./examples/scripts/train_llama_slurm.sh) and [Ray PPO](./examples/scripts/train_ppo_llama_ray_slurm.sh). + + +### PPO Support Matrix + +| Feature | OpenRLHF | DSChat | CAIChat | TRL | +| ------------- |:-------------:| :-------------:| :-------------:| :-------------:| +| 70B+ Full Tuning with 16 A100-80GB | ✅ | ❌ | ❌ | ❌ | +| 7B Full Tuning with 4 RTX4090 | ✅ | ❌ | ❌ | ❌ | +| 34B DPO Full Tuning with 8 A100-80GB | ✅ | ❌ | ❌ | ❌ | +| Inference Engine in PPO | ✅ | ✅ | ❌ | ❌ | +| PPO Implementation Tricks | ✅ | ❌ | ❌ | ✅ | +| Support QLoRA | ✅ | ❌ | ❌ | ✅ | +| Support Mixtral 8*7b | ✅ | ❌ | ❌ | ❌ | +| Support Unmerged Actor-Critic | ✅ | ✅ | ✅ | ❌ | +| Support Multiple Reward Models | ✅ | ❌ | ❌ | ❌ | +| Support Huggingface Models | ✅ | ✅ | ✅ | ✅ | +| Easy-to-use | ✅ | ❌ (HybridEngine bugs) | ✅ | ✅ | + + +## Quick Start + +### Installation + +To use OpenRLHF, first launch the docker container (**Recommended**) and `pip install` openrlhf inside the docker container: + +```bash +# Launch the docker container +docker run --runtime=nvidia -it --rm --shm-size="10g" --cap-add=SYS_ADMIN -v $PWD:/openrlhf nvcr.io/nvidia/pytorch:24.07-py3 bash +sudo pip uninstall xgboost transformer_engine flash_attn -y + +# pip install +pip install openrlhf + +# If you want to use vLLM acceleration (Install vLLM 0.6.5) +pip install openrlhf[vllm] +# latest vLLM is also supported +pip install openrlhf[vllm_latest] + +# pip install the latest version +pip install git+https://github.com/OpenRLHF/OpenRLHF.git + +# Or git clone +git clone https://github.com/OpenRLHF/OpenRLHF.git +cd OpenRLHF +pip install -e . +``` + +> [!NOTE] +>We recommend using vLLM 0.6.4 or higher. Other versions (vLLM >= 0.4.2) may require weight synchronization via Gloo (`--vllm_sync_backend gloo`). +>We also provided the [Dockerfiles for vLLM](./dockerfile/) and [One-Click Installation Script of Nvidia-Docker](./examples/scripts/nvidia_docker_install.sh). + +### Prepare Datasets +OpenRLHF provides multiple data processing methods in our dataset classes. 
+For example, in the [Prompt Dataset](https://github.com/OpenRLHF/OpenRLHF/blob/main/openrlhf/datasets/prompts_dataset.py#L6):
+
+```python
+def preprocess_data(data, input_template=None, input_key="input", apply_chat_template=None) -> str:
+    if apply_chat_template:
+        chat = data[input_key]
+        if isinstance(chat, str):
+            chat = [{"role": "user", "content": chat}]
+        prompt = apply_chat_template(chat, tokenize=False, add_generation_prompt=True)
+    else:
+        prompt = data[input_key]
+        if input_template:
+            prompt = input_template.format(prompt)
+    return prompt
+```
+
+- We can use `--input_key` to specify the `JSON key name` of the input datasets passed via `--prompt_data {name or path}` (PPO) or `--dataset {name or path}`, and use `--apply_chat_template` to utilize the `chat_template` from the [Huggingface Tokenizer](https://huggingface.co/docs/transformers/main/en/chat_templating).
+- If you don't want to use `--apply_chat_template`, you can use `--input_template` instead, or preprocess the datasets offline in advance.
+- OpenRLHF also supports mixing multiple datasets using `--prompt_data_probs 0.1,0.4,0.5` (PPO) or `--dataset_probs 0.1,0.4,0.5`.
+
+How Chat Templating Works:
+
+```python
+dataset = [{"input_key": [
+  {"role": "user", "content": "Hello, how are you?"},
+  {"role": "assistant", "content": "I'm doing great. How can I help you today?"},
+  {"role": "user", "content": "I'd like to show off how chat templating works!"},
+]}]
+
+tokenizer.apply_chat_template(dataset[0]["input_key"], tokenize=False)
+
+"[INST] Hello, how are you? [/INST]I'm doing great. How can I help you today? [INST] I'd like to show off how chat templating works! [/INST]"
+```
+
+How to specify training and test datasets?
+
+You can specify it using the `data_type@data_dir` format. For example, the dataset can be set as `--dataset json@./data`.
+
+```
+data
+├── test.jsonl
+└── train.jsonl
+```
+
+> [!NOTE]
+> By default, we use `train` and `test` as splits to distinguish training and testing datasets from Huggingface.
+> The ``JSON key`` options depend on the specific datasets. See [Reward Dataset](https://github.com/OpenRLHF/OpenRLHF/blob/main/openrlhf/datasets/reward_dataset.py#L10) and [SFT Dataset](https://github.com/OpenRLHF/OpenRLHF/blob/main/openrlhf/datasets/sft_dataset.py#L9).
+
+### Supervised Fine-tuning
+
+OpenRLHF's model checkpoint is fully compatible with HuggingFace models. You can specify the model name or path using `--pretrain {name or path}`, `--reward_pretrain {name or path}` and `--critic_pretrain {name or path}`. We have provided some pre-trained checkpoints and datasets on [HuggingFace OpenRLHF](https://huggingface.co/OpenRLHF).
+
+Then you can use the startup scripts we provide in the [examples/scripts](./examples/scripts/) directory, or start the training using the following commands. 
+ +```bash +deepspeed --module openrlhf.cli.train_sft \ + --max_len 4096 \ + --dataset Open-Orca/OpenOrca \ + --input_key question \ + --output_key response \ + --input_template $'User: {}\nAssistant: ' \ + --train_batch_size 256 \ + --micro_train_batch_size 2 \ + --max_samples 500000 \ + --pretrain meta-llama/Meta-Llama-3-8B \ + --save_path ./checkpoint/llama3-8b-sft \ + --save_steps -1 \ + --logging_steps 1 \ + --eval_steps -1 \ + --zero_stage 2 \ + --max_epochs 1 \ + --packing_samples \ + --bf16 \ + --flash_attn \ + --learning_rate 5e-6 \ + --gradient_checkpointing \ + --use_wandb {wandb_token} + +# Support HF tokenizer.apply_chat_template +# --apply_chat_template +# --tokenizer_chat_template {HF Chat Template} + +# Support RingAttention +# pip install ring_flash_attn +# --ring_attn_size 2 \ +# --ring_head_stride 2 \ + +# Multi-turn fine-tuning loss +# --multiturn + +# Can also be used for continued pre-training +# --pretrain_mode +``` + +> [!NOTE] +> OpenRLHF SFT/DPO/RewardModel/PPO trainers support `--packing_samples` [based on `--flash_attn`](https://github.com/MeetKai/functionary/tree/main/functionary/train/packing) + + +### Reward Model Training +```bash +deepspeed --module openrlhf.cli.train_rm \ + --save_path ./checkpoint/llama3-8b-rm \ + --save_steps -1 \ + --logging_steps 1 \ + --eval_steps -1 \ + --train_batch_size 256 \ + --micro_train_batch_size 1 \ + --pretrain OpenRLHF/Llama-3-8b-sft-mixture \ + --bf16 \ + --max_epochs 1 \ + --max_len 8192 \ + --zero_stage 3 \ + --learning_rate 9e-6 \ + --dataset OpenRLHF/preference_dataset_mixture2_and_safe_pku \ + --apply_chat_template \ + --chosen_key chosen \ + --rejected_key rejected \ + --flash_attn \ + --packing_samples \ + --gradient_checkpointing \ + --use_wandb {wandb_token} + +``` + +It is recommended to set the `--value_prefix_head` option of the Reward Model to `score`, so that we can load the model using `AutoModelForSequenceClassification`: + +```python +reward_model = AutoModelForSequenceClassification.from_pretrained( + reward_model_path, + num_labels=1, + torch_dtype=torch.bfloat16, + attn_implementation="flash_attention_2", + use_cache=False, + ) +inputs = xxxx (Left Padding Input Tokens) +reward = reward_model.model(*inputs).last_hidden_state +reward = reward_model.score(reward)[:, -1] +``` + +### PPO without Ray + +```bash +deepspeed --module openrlhf.cli.train_ppo \ + --pretrain OpenRLHF/Llama-3-8b-sft-mixture \ + --reward_pretrain OpenRLHF/Llama-3-8b-rm-mixture \ + --save_path ./checkpoint/llama-3-8b-rlhf \ + --save_steps -1 \ + --logging_steps 1 \ + --eval_steps -1 \ + --micro_train_batch_size 2 \ + --train_batch_size 128 \ + --micro_rollout_batch_size 4 \ + --rollout_batch_size 1024 \ + --max_epochs 1 \ + --prompt_max_len 1024 \ + --generate_max_len 1024 \ + --zero_stage 2 \ + --bf16 \ + --actor_learning_rate 5e-7 \ + --critic_learning_rate 9e-6 \ + --init_kl_coef 0.01 \ + --prompt_data OpenRLHF/prompt-collection-v0.1 \ + --input_key context_messages \ + --apply_chat_template \ + --max_samples 100000 \ + --normalize_reward \ + --adam_offload \ + --flash_attn \ + --gradient_checkpointing \ + --use_wandb {wandb_token} + +# Support remote reward model (HTTP) +# --remote_rm_url http://localhost:5000/get_reward +``` + +### PPO/REINFORCE++ with Ray and vLLM + +To improve RLHF training speed or support 70B models, we can use the PPO with Ray and vLLM acceleration + +```bash +# launch the master node of ray in container +ray start --head --node-ip-address 0.0.0.0 --num-gpus 8 + +# if you want to launch ray on more nodes, 
use
+ray start --address {MASTER-NODE-ADDRESS}:6379 --num-gpus 8
+
+ray job submit --address="http://127.0.0.1:8265" \
+   --runtime-env-json='{"working_dir": "/openrlhf"}' \
+   -- python3 -m openrlhf.cli.train_ppo_ray \
+   --ref_num_nodes 1 \
+   --ref_num_gpus_per_node 2 \
+   --reward_num_nodes 1 \
+   --reward_num_gpus_per_node 2 \
+   --critic_num_nodes 1 \
+   --critic_num_gpus_per_node 2 \
+   --actor_num_nodes 1 \
+   --actor_num_gpus_per_node 2 \
+   --vllm_num_engines 2 \
+   --vllm_tensor_parallel_size 2 \
+   --colocate_critic_reward \
+   --colocate_actor_ref \
+   --pretrain OpenRLHF/Llama-3-8b-sft-mixture \
+   --reward_pretrain OpenRLHF/Llama-3-8b-rm-mixture \
+   --save_path /openrlhf/examples/checkpoint/llama3-8b-rlhf \
+   --micro_train_batch_size 8 \
+   --train_batch_size 128 \
+   --micro_rollout_batch_size 16 \
+   --rollout_batch_size 1024 \
+   --max_samples 100000 \
+   --max_epochs 1 \
+   --prompt_max_len 1024 \
+   --generate_max_len 1024 \
+   --zero_stage 3 \
+   --bf16 \
+   --actor_learning_rate 5e-7 \
+   --critic_learning_rate 9e-6 \
+   --init_kl_coef 0.01 \
+   --prompt_data OpenRLHF/prompt-collection-v0.1 \
+   --input_key context_messages \
+   --apply_chat_template \
+   --normalize_reward \
+   --packing_samples \
+   --adam_offload \
+   --flash_attn \
+   --gradient_checkpointing \
+   --use_wandb {wandb_token}
+
+# Support REINFORCE++ | RLOO
+# --advantage_estimator reinforce | rloo
+
+# Support remote reward model (HTTP)
+# --remote_rm_url http://localhost:5000/get_reward
+
+# Support N samples
+# --n_samples_per_prompt 4
+```
+> [!NOTE]
+> Leaving `--vllm_num_engines` unset means the vLLM engine is not used.
+> You can also use ``setup_commands`` to let Ray automatically deploy the environment, such as `--runtime-env-json='{"setup_commands": ["pip install openrlhf[vllm]"]}'`.
+
+> [!NOTE]
+> RLOO in OpenRLHF is a modification based on REINFORCE++ and differs from the original version.
+
+> [!NOTE]
+> If you encounter an index-out-of-range error when DeepSpeed sets up the GPU devices, you can try setting the environment variable [`RAY_EXPERIMENTAL_NOSET_*_VISIBLE_DEVICES`](openrlhf/trainer/ray/utils.py) as a workaround.
+> ```bash
+> # For NVIDIA GPUs:
+> export RAY_EXPERIMENTAL_NOSET_CUDA_VISIBLE_DEVICES=1
+> ```
+
+The launch scripts and documents for the supported algorithms are in [examples/scripts](./examples/scripts/) and [Documents - Usage](https://openrlhf.readthedocs.io/en/latest/usage.html).
+
+### LoRA
+If you use `LoRA (Low-Rank Adaptation)`, `OpenRLHF` does not save the full model weights by default; it saves only the `LoRA Adapter`. To use the model in your task, you should merge the `Adapter` into the weights of your base model:
+
+```bash
+python -m openrlhf.cli.lora_combiner \
+   --model_path meta-llama/Meta-Llama-3-8B \
+   --lora_path ./checkpoint/llama3-8b-rm \
+   --output_path ./checkpoint/llama-3-8b-rm-combined \
+   --is_rm \
+   --bf16
+```
+
+## Performance
+
+We optimized DSChat's performance to the greatest extent possible by employing techniques such as enabling Adam offload, along with reward model (RM) and reference model (Ref) offload, to increase the micro-batch size during the inference stage and avoid out-of-memory issues. We even fixed some bugs in DSChat to enable the Hybrid Engine (HE) for LLaMA2. The average time (in seconds) to train 1024 prompts for one PPO epoch using Optimized DSChat and OpenRLHF:
+
+| **Size** | **NVIDIA A800-80GB GPUs** | **Optimized DSChat (with Hybrid Engine)** | **OpenRLHF** | **Speedup** |
+| :---: | :---: | :---: | :---: | :---: |
+| 7B | 16 | 855.09 | 471.11 | 1.82x |
+| 13B | 32 | 1528.93 | 608.93 | 2.5x |
+| 34B | 32 | 3634.98 | 1526.4 | 2.4x |
+| 70B | 32 | 10407.0 | 4488.53 | 2.3x |
+
+> [!NOTE]
+> The data is outdated; please refer to the performance tuning section for re-testing.
+
+### Performance Tuning Guide
+
+To achieve optimal performance, we recommend allocating more nodes to the vLLM Engine. For example, for a 70B model with 32 A100 GPUs, it is advised to allocate 16 A100 GPUs to the vLLM Engine, 8 GPUs to the Actor model, and the remaining 8 GPUs to the Critic model. Additionally, enable the `--colocate_critic_reward` and `--colocate_actor_ref` options to merge nodes. Finally, you should increase `rollout_micro_batch_size` (and minimize the TP size of the vLLM engine) as much as possible. During the training phase, a larger `--micro_train_batch_size` is better, and `--packing_samples` should be enabled. When there are enough GPUs, please disable `--adam_offload` and enable `--overlap_comm`. For multi-node RLHF, please use `--vllm_sync_backend nccl` with vLLM 0.6.4+.
+
+## Companies and Organizations using OpenRLHF
+
+- Google
+- ByteDance
+- Tencent
+- Alibaba
+- Baidu
+- China Telecom
+- Vivo
+- Allen AI
+- NexusFlow
+- Jülich Supercomputing Centre (JSC)
+- Berkeley Starling Team
+- M-A-P
+- ...
+
+## Join Us
+
+**How to Join?**
+
+1. Email us at janhu9527@gmail.com or join [GitHub Organization](https://github.com/OpenRLHF). Please include the following details:
+   - Your name
+   - Your GitHub username
+   - Your areas of interest
+   - Your skills and experience related to NLP and/or AI
+1. You can also join us through the official GitHub [OpenRLHF ↗](https://github.com/OpenRLHF/OpenRLHF) project page. Just create an issue about your interest to contribute and we will get back to you.
+
+**What can you do?**
+
+1. Join the team and participate in the development of the OpenRLHF project.
+1. Contribute to the project by submitting pull requests.
+1. Help improve documentation, fix bugs, or create new features.
+1. Share the project and help us grow the community.
+
+## Sponsor Us
+
+Your sponsorship can help us maintain and improve OpenRLHF. If you find this project useful, please consider sponsoring us. You can sponsor us on [Open Collective ↗](https://opencollective.com/OpenRLHF).
+
+## Starchart
+
+[![Star History Chart](https://api.star-history.com/svg?repos=OpenRLHF/OpenRLHF&type=Date)](https://star-history.com/#OpenRLHF/OpenRLHF&Date)
+
+## Contributors
+
+A big thank you to all our contributors! If you want to contribute, feel free to make a pull request or create an issue. 
+ + + + + +## References & Acknowledgements + +We would like to express our gratitude to the following projects and organizations for their contributions to the field of AI and NLP: + +- [Hugging Face Transformers ↗](https://github.com/huggingface/transformers) +- [OpenAI GPT ↗](https://github.com/openai/gpt-3) +- [LLaMA ↗](https://llama.meta.com/) +- [DeepSpeed ↗](https://github.com/microsoft/DeepSpeed) +- [Ray ↗](https://github.com/ray-project/ray) + +Our project would also like to thank [ColossalChat](https://github.com/hpcaitech/ColossalAI/tree/main/applications/Chat) and [DeepSpeedChat](https://github.com/microsoft/DeepSpeedExamples/tree/master/applications/DeepSpeed-Chat). In the early stages of the project, we referred to their code design. + +(2024/7) Our GitHub organization has changed from OpenLLMAI to OpenRLHF. + +## Citation +``` +@article{hu2024openrlhf, + title={OpenRLHF: An Easy-to-use, Scalable and High-performance RLHF Framework}, + author={Jian Hu and Xibin Wu and Zilin Zhu and Xianyu and Weixun Wang and Dehao Zhang and Yu Cao}, + journal={arXiv preprint arXiv:2405.11143}, + year={2024} +} +``` + +______________________________________________________________________ + +*OpenRLHF © 2025 OpenRLHF. All Rights Reserved.* diff --git a/PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/README_ja.md b/PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/README_ja.md new file mode 100644 index 0000000000..2faea1f87e --- /dev/null +++ b/PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/README_ja.md @@ -0,0 +1,458 @@ +
+OpenRLHF logo
+Badges: GitHub Contributors | Issues | GitHub pull requests | GitHub stars
+オープンソース / 包括的 / 軽量 / 使いやすい
+ +[ English | 中文 | 日本語 ] + +OpenRLHFは、Ray、DeepSpeed、およびHF Transformersを基盤とした高性能なRLHFフレームワークです: + +- **シンプルで使いやすい**: OpenRLHFは現在利用可能な最もシンプルな高性能RLHFライブラリの一つであり、Huggingfaceのモデルとデータセットとシームレスに互換性があります。 +- **高性能**: RLHFトレーニングの80%の時間はサンプル生成段階に費やされます。RayとPacking SamplesおよびvLLM生成加速の能力を活用することで、OpenRLHFのパフォーマンスはOptimized DeepSpeedChat with Hybrid Engineの3〜4倍以上です。 +- **分散RLHF**: OpenRLHFは、Actor、Reward、Reference、およびCriticモデルをRayを使用して別々のGPUに分散し、AdamオプティマイザをCPUに配置します。これにより、複数のA100 80G GPUとvLLMを使用して70B+モデルのフルスケールの微調整が可能になり、複数の24GB RTX 4090 GPUで7Bモデルを微調整できます。 +- **PPO実装の最適化**: トレーニングの安定性を向上させるために、PPOの実装トリックを統合しました。詳細は[Zhihu](https://zhuanlan.zhihu.com/p/622134699)および[Notionブログ](https://hijkzzz.notion.site/rlhf-implementation-tricks?v=158d9a33ecc98132bf9e000c39227361)を参照してください。 + +詳細は[スライド](https://docs.google.com/presentation/d/1JRhB1d7csofx0PIZBmfyBdMluxNd5JLPpUHrrvVhGnk/edit?usp=sharing) | [技術報告](https://arxiv.org/abs/2405.11143) | [ドキュメント](https://openrlhf.readthedocs.io/)をご覧ください。 + +## ニュース +- [2024/12] 私たちは😊 [REINFORCE++: A Simple and Efficient Approach for Aligning Large Language Models](https://www.researchgate.net/publication/387487679_REINFORCE_A_SIMPLE_AND_EFFICIENT_APPROACH_FOR_ALIGNING_LARGE_LANGUAGE_MODELS)を「提案」しました。 +- [2024/12] [Notionブログ](https://hijkzzz.notion.site/unraveling-rlhf-and-its-variants-engineering-insights#147d9a33ecc9806090f3d5c749d31f05)でPPO、REINFORCE++、GRPO、およびRLOOを分析しました。 + +## 特徴 + +- Rayに基づく分散[ PPO](./examples/scripts/train_ppo_llama_ray.sh)および[REINFORCE++/RLOO](./examples/scripts/train_reinforce_llama_ray.sh)の実装。 +- [70億以上のパラメータを持つモデル](./examples/scripts/train_ppo_llama_ray_70b.sh)の完全なRLHF微調整のサポート。 +- RLHFタスクでの生成を加速するためのvLLMの統合(`--vllm_num_engines`)。 +- 複数の報酬モデル(`--reward_pretrain model1,model2...`)およびリモート報酬モデル(`--remote_rm_url`)のサポート。 +- [DPO(直接選好最適化)/IPO/cDPO](./examples/scripts/train_dpo_llama.sh)および[Kahneman-Tversky Optimization(KTO)](./examples/scripts/train_kto_llama.sh)の実装。 +- [反復DPO](./examples/scripts/train_iterative_dpo_llama.sh)([GitHub: Online-RLHF](https://github.com/RLHFlow/Online-RLHF))のサポート。 +- [拒否サンプリング](./examples/scripts/train_rejection_sampling_llama.sh)のサポート。 +- [条件付きSFT](./examples/scripts/train_conditional_llama.sh)([arXiv:2308.12050](https://arxiv.org/abs/2308.12050))の実装。 +- [知識蒸留](./examples/scripts/train_knowledge_distillation.sh)([Microsoft: minillm](https://github.com/microsoft/LMOps/tree/main/minillm))のサポート。 +- [プロセス報酬モデル(PRM)](./examples/scripts/train_prm_mistral.sh)の統合。 +- SFT、DPO、RM、PRM、およびPPOのトレーニングサンプルのパッキング(`--packing_samples`)。 +- [RingAttention](./examples/scripts/train_dpo_ring_llama.sh)の実装(`--ring_attn_size`、`--ring_head_stride`)。 +- [専門家の混合モデル(MoE)](./examples/test_scripts/train_sft_mixtral_lora.sh)のサポート(`--aux_loss_coef`)。 +- FlashAttention2の統合(`--flash_attn`)。 +- QLoRA(`--load_in_4bit`)および[LoRA](./examples/scripts/train_sft_mixtral_lora.sh)(`--lora_rank`、`--target_modules`)のサポート。 +- HuggingFaceの`tokenizer.apply_chat_template`との互換性(`--apply_chat_template`および`--input_key`)。 +- Wandb(`--use_wandb`)およびTensorBoard(`--use_tensorboard`)によるログ記録のサポート。 +- チェックポイントの回復機能(`--load_checkpoint`および`--save_steps`)。 +- [DPO](./examples/scripts/train_llama_slurm.sh)および[Ray PPO](./examples/scripts/train_ppo_llama_ray_slurm.sh)などのマルチノードトレーニングスクリプトを提供。 + +### PPOサポートマトリックス + +| 特徴 | OpenRLHF | DSChat | CAIChat | TRL | +| ------------- |:-------------:| :-------------:| :-------------:| :-------------:| +| 16 A100-80GBで70B+のフルチューニング | ✅ | ❌ | ❌ | ❌ | +| 4 RTX4090で7Bのフルチューニング | ✅ | ❌ | ❌ | ❌ | +| 8 A100-80GBで34B DPOのフルチューニング | ✅ | ❌ | ❌ | ❌ | +| 
PPOでの推論エンジンのサポート | ✅ | ✅ | ❌ | ❌ | +| PPO実装のトリック | ✅ | ❌ | ❌ | ✅ | +| QLoRAのサポート | ✅ | ❌ | ❌ | ✅ | +| Mixtral 8*7bのサポート | ✅ | ❌ | ❌ | ❌ | +| 未結合のActor-Criticのサポート | ✅ | ✅ | ✅ | ❌ | +| 複数の報酬モデルのサポート | ✅ | ❌ | ❌ | ❌ | +| Huggingfaceモデルのサポート | ✅ | ✅ | ✅ | ✅ | +| 使いやすさ | ✅ | ❌ (HybridEngineのバグ) | ✅ | ✅ | + +## クイックスタート + +### インストール + +OpenRLHFを使用するには、まずDockerコンテナを起動し(**推奨**)、Dockerコンテナ内で`pip install`を実行してopenrlhfをインストールします: + +```bash +# Dockerコンテナを起動 +docker run --runtime=nvidia -it --rm --shm-size="10g" --cap-add=SYS_ADMIN -v $PWD:/openrlhf nvcr.io/nvidia/pytorch:24.07-py3 bash +sudo pip uninstall xgboost transformer_engine flash_attn -y + +# pip install +pip install openrlhf + +# vLLM加速を使用する場合(vLLM 0.6.5をインストール) +pip install openrlhf[vllm] +# 最新のvLLMもサポートされています +pip install openrlhf[vllm_latest] + +# 最新バージョンをpip install +pip install git+https://github.com/OpenRLHF/OpenRLHF.git + +# またはgit clone +git clone https://github.com/OpenRLHF/OpenRLHF.git +cd OpenRLHF +pip install -e . +``` + +> [!NOTE] +>vLLM 0.6.4以降の使用をお勧めします。他のバージョン(vLLM >= 0.4.2)は、Glooを介して重みの同期が必要な場合があります(`--vllm_sync_backend gloo`)。 +>また、[vLLM用のDockerfile](./dockerfile/)および[Nvidia-Dockerのワンクリックインストールスクリプト](./examples/scripts/nvidia_docker_install.sh)も提供しています。 + +### データセットの準備 +OpenRLHFは、データセットクラス内で複数のデータ処理方法を提供しています。 +例えば、[Prompt Dataset](https://github.com/OpenRLHF/OpenRLHF/blob/main/openrlhf/datasets/prompts_dataset.py#L6)では: + +```python +def preprocess_data(data, input_template=None, input_key="input", apply_chat_template=None) -> str: + if apply_chat_template: + chat = data[input_key] + if isinstance(chat, str): + chat = [{"role": "user", "content": chat}] + prompt = apply_chat_template(chat, tokenize=False, add_generation_prompt=True) + else: + prompt = data[input_key] + if input_template: + prompt = input_template.format(prompt) + return prompt +``` + +- `--input_key`を使用して、入力データセットの`JSON key name`を指定し、`--prompt_data {name or path}`(PPO)または`--dataset {name or path}`を使用し、`--apply_chat_template`を使用して[Huggingface Tokenizer](https://huggingface.co/docs/transformers/main/en/chat_templating)の`chat_template`を利用できます。 +- `--apply_chat_template`を使用したくない場合は、代わりに`--input_template`を使用するか、事前にデータセットをオフラインで前処理することができます。 +- OpenRLHFは、`--prompt_data_probs 0.1,0.4,0.5`(PPO)または`--dataset_probs 0.1,0.4,0.5`を使用して複数のデータセットを混合することもサポートしています。 + +Chat Templatingの動作方法: + +```python +dataset = [{"input_key": [ + {"role": "user", "content": "Hello, how are you?"}, + {"role": "assistant", "content": "I'm doing great. How can I help you today?"}, + {"role": "user", "content": "I'd like to show off how chat templating works!"}, +]}] + +tokenizer.apply_chat_template(dataset[0]["input_key"], tokenize=False) + +"[INST] Hello, how are you? [/INST]I'm doing great. How can I help you today? [INST] I'd like to show off how chat templating works! 
[/INST]" +``` + +トレーニングおよびテストデータセットの指定方法: + +`data_type@data_dir`形式を使用して指定できます。例えば、データセットは`--dataset json@./data`として設定できます。 + +``` +data +├── test.jsonl +└── train.jsonl +``` + +> [!NOTE] +> デフォルトでは、`train`および`test`を使用してHuggingfaceのトレーニングおよびテストデータセットを区別します。 +> `JSON key`オプションは特定のデータセットに依存します。詳細は[Reward Dataset](https://github.com/OpenRLHF/OpenRLHF/blob/main/openrlhf/datasets/reward_dataset.py#L10)および[SFT Dataset](https://github.com/OpenRLHF/OpenRLHF/blob/main/openrlhf/datasets/sft_dataset.py#L9)を参照してください。 + +### 教師あり微調整 + +OpenRLHFのモデルチェックポイントはHuggingFaceモデルと完全に互換性があります。`--pretrain {name or path}`、`--reward_pretrain {name or path}`、および`--critic_pretrain {name or path}`を使用してモデル名またはパスを指定できます。いくつかの事前トレーニング済みチェックポイントとデータセットを[HuggingFace OpenRLHF](https://huggingface.co/OpenRLHF)で提供しています。 + +次に、[examples/scripts](./examples/scripts/)ディレクトリに提供されている起動スクリプトを使用するか、以下のコマンドを使用してトレーニングを開始できます。 + +```bash +deepspeed --module openrlhf.cli.train_sft \ + --max_len 4096 \ + --dataset Open-Orca/OpenOrca \ + --input_key question \ + --output_key response \ + --input_template $'User: {}\nAssistant: ' \ + --train_batch_size 256 \ + --micro_train_batch_size 2 \ + --max_samples 500000 \ + --pretrain meta-llama/Meta-Llama-3-8B \ + --save_path ./checkpoint/llama3-8b-sft \ + --save_steps -1 \ + --logging_steps 1 \ + --eval_steps -1 \ + --zero_stage 2 \ + --max_epochs 1 \ + --packing_samples \ + --bf16 \ + --flash_attn \ + --learning_rate 5e-6 \ + --gradient_checkpointing \ + --use_wandb {wandb_token} + +# HF tokenizer.apply_chat_templateのサポート +# --apply_chat_template +# --tokenizer_chat_template {HF Chat Template} + +# RingAttentionのサポート +# pip install ring_flash_attn +# --ring_attn_size 2 \ +# --ring_head_stride 2 \ + +# 継続的な事前トレーニングにも使用できます +# --pretrain_mode +``` + +> [!NOTE] +> OpenRLHF SFT/DPO/RewardModel/PPOトレーナーは`--packing_samples`をサポートしています [`--flash_attn`に基づく](https://github.com/MeetKai/functionary/tree/main/functionary/train/packing) + +### 報酬モデルのトレーニング +```bash +deepspeed --module openrlhf.cli.train_rm \ + --save_path ./checkpoint/llama3-8b-rm \ + --save_steps -1 \ + --logging_steps 1 \ + --eval_steps -1 \ + --train_batch_size 256 \ + --micro_train_batch_size 1 \ + --pretrain OpenRLHF/Llama-3-8b-sft-mixture \ + --bf16 \ + --max_epochs 1 \ + --max_len 8192 \ + --zero_stage 3 \ + --learning_rate 9e-6 \ + --dataset OpenRLHF/preference_dataset_mixture2_and_safe_pku \ + --apply_chat_template \ + --chosen_key chosen \ + --rejected_key rejected \ + --flash_attn \ + --packing_samples \ + --gradient_checkpointing \ + --use_wandb {wandb_token} + +``` + +報酬モデルの`--value_prefix_head`オプションを`score`に設定することをお勧めします。これにより、`AutoModelForSequenceClassification`を使用してモデルをロードできます: + +```python +reward_model = AutoModelForSequenceClassification.from_pretrained( + reward_model_path, + num_labels=1, + torch_dtype=torch.bfloat16, + attn_implementation="flash_attention_2", + use_cache=False, + ) +inputs = xxxx (Left Padding Input Tokens) +reward = reward_model.model(*inputs).last_hidden_state +reward = reward_model.score(reward)[:, -1] +``` + +### Rayを使用しないPPO + +```bash +deepspeed --module openrlhf.cli.train_ppo \ + --pretrain OpenRLHF/Llama-3-8b-sft-mixture \ + --reward_pretrain OpenRLHF/Llama-3-8b-rm-mixture \ + --save_path ./checkpoint/llama-3-8b-rlhf \ + --save_steps -1 \ + --logging_steps 1 \ + --eval_steps -1 \ + --micro_train_batch_size 2 \ + --train_batch_size 128 \ + --micro_rollout_batch_size 4 \ + --rollout_batch_size 1024 \ + --max_epochs 1 \ + --prompt_max_len 1024 \ + --generate_max_len 1024 \ + --zero_stage 2 \ + 
--bf16 \ + --actor_learning_rate 5e-7 \ + --critic_learning_rate 9e-6 \ + --init_kl_coef 0.01 \ + --prompt_data OpenRLHF/prompt-collection-v0.1 \ + --input_key context_messages \ + --apply_chat_template \ + --max_samples 100000 \ + --normalize_reward \ + --adam_offload \ + --flash_attn \ + --gradient_checkpointing \ + --use_wandb {wandb_token} + +# リモート報酬モデルのサポート(HTTP) +# --remote_rm_url http://localhost:5000/get_reward +``` + +### RayとvLLMを使用したPPO/REINFORCE++ + +RLHFトレーニング速度を向上させるか、70Bモデルをサポートするために、RayとvLLM加速を使用したPPOを使用できます + +```bash +# コンテナ内でRayのマスターノードを起動 +ray start --head --node-ip-address 0.0.0.0 --num-gpus 8 + +# さらに多くのノードでRayを起動する場合は +ray start --address {MASTER-NODE-ADDRESS}:6379 --num-gpus 8 + +ray job submit --address="http://127.0.0.1:8265" \ + --runtime-env-json='{"working_dir": "/openrlhf"}' \ + -- python3 -m openrlhf.cli.train_ppo_ray \ + --ref_num_nodes 1 \ + --ref_num_gpus_per_node 2 \ + --reward_num_nodes 1 \ + --reward_num_gpus_per_node 2 \ + --critic_num_nodes 1 \ + --critic_num_gpus_per_node 2 \ + --actor_num_nodes 1 \ + --actor_num_gpus_per_node 2 \ + --vllm_num_engines 2 \ + --vllm_tensor_parallel_size 2 \ + --colocate_critic_reward \ + --colocate_actor_ref \ + --pretrain OpenRLHF/Llama-3-8b-sft-mixture \ + --reward_pretrain OpenRLHF/Llama-3-8b-rm-mixture \ + --save_path /openrlhf/examples/checkpoint/llama3-8b-rlhf \ + --micro_train_batch_size 8 \ + --train_batch_size 128 \ + --micro_rollout_batch_size 16 \ + --rollout_batch_size 1024 \ + --max_samples 100000 \ + --max_epochs 1 \ + --prompt_max_len 1024 \ + --generate_max_len 1024 \ + --zero_stage 3 \ + --bf16 \ + --actor_learning_rate 5e-7 \ + --critic_learning_rate 9e-6 \ + --init_kl_coef 0.01 \ + --prompt_data OpenRLHF/prompt-collection-v0.1 \ + --input_key context_messages \ + --apply_chat_template \ + --normalize_reward \ + --packing_samples \ + --adam_offload \ + --flash_attn \ + --gradient_checkpointing \ + --use_wandb {wandb_token} + +# REINFORCE++ | RLOOのサポート +# --advantage_estimator reinforce | rloo + +# リモート報酬モデルのサポート(HTTP) +# --remote_rm_url http://localhost:5000/get_reward + + +# Nサンプルのサポート +# --n_samples_per_prompt 4 +``` +> [!NOTE] +> `--vllm_num_engines`を設定しない場合は、vLLMエンジンを使用しないことを意味します。 +> `setup_commands`を使用してRayが自動的に環境をデプロイすることもできます。例えば、`--runtime-env-json='{"setup_commands": ["pip install openrlhf[vllm]"]}'`。 + +[!NOTE] +OPENRLHFのRLOOは、REINFORCE++を基に改良されたものであり、オリジナル版とは異なります。 + +> [!NOTE] +> deepspeedがGPUデバイスを設定する際にインデックスが範囲外に関連するエラーが発生した場合、環境変数[`RAY_EXPERIMENTAL_NOSET_*_VISIBLE_DEVICES`](openrlhf/trainer/ray/utils.py)を設定して回避策を試すことができます。 +> ```bash +> # NVIDIA GPUの場合: +> export RAY_EXPERIMENTAL_NOSET_CUDA_VISIBLE_DEVICES=1 +> ``` + +サポートされているアルゴリズムの起動スクリプトとドキュメントは[example/scripts](./examples/scripts/)および[Documents - Usage](https://openrlhf.readthedocs.io/en/latest/usage.html)にあります。 + +## パフォーマンス + +Adamオフロードの有効化、報酬モデル(RM)および参照モデル(Ref)オフロードなどの技術を使用して、DSChatのパフォーマンスを最大限に最適化し、推論段階でのマイクロバッチサイズを増やし、メモリ不足の問題を回避しました。LLaMA2のハイブリッドエンジン(HE)を有効にするために、DSChatのいくつかのバグも修正しました。Optimized DSChatとOpenRLHFを使用して1024のプロンプトを1つのPPOエポックでトレーニングするのにかかる平均時間(秒)は次のとおりです: + +| **サイズ** | **NVIDIA A800-80GB GPU** | **Optimized DSChat(ハイブリッドエンジン付き)** | **OpenRLHF** | **スピードアップ** | +| :---: | :---: | :---: | :---: | :---: | +| 7B | 16 | 855.09 | 471.11 | 1.82x | +| 13B | 32 | 1528.93 | 608.93 | 2.5x | +| 34B | 32 | 3634.98 | 1526.4 | 2.4x | +| 70B | 32 | 10407.0 | 4488.53 | 2.3x | + +> [!NOTE] +> データは古くなっています。再テストのためにパフォーマンスチューニングセクションを参照してください。 + +### パフォーマンスチューニングガイド + 
+最適なパフォーマンスを達成するために、vLLMエンジンにより多くのノードを割り当てることをお勧めします。例えば、32個のA100 GPUを持つ70Bモデルの場合、16個のA100 GPUをvLLMエンジンに割り当て、8個のGPUをActorモデルに、残りの8個のGPUをCriticモデルに割り当てることをお勧めします。さらに、`--colocate_critic_reward`、`--colocate_actor_ref`オプションを有効にしてノードをマージします。最後に、`rollout_micro_batch_size`(およびvLLMエンジンのTPサイズを最小化)を可能な限り増やすべきです。トレーニングフェーズでは、より大きな`--micro_train_batch_size`が望ましく、`--packing_samples`を有効にします。十分なGPUがある場合、`--adam_offload`を無効にし、`--overlap_comm`を有効にします。マルチノードRLHFの場合、vLLM 0.6.4+で`--vllm_sync_backend nccl`を使用してください。 + +## OpenRLHFを使用している企業と組織 + +- Google +- ByteDance +- Tencent +- Alibaba +- Baidu +- China Telecom +- Vivo +- Allen AI +- NexusFlow +- Jülich Supercomputing Centre (JSC) +- Berkeley Starling Team +- M-A-P +- ... + +## 参加方法 + +**参加方法** + +1. janhu9527@gmail.comにメールを送るか、[GitHub Organization](https://github.com/OpenRLHF)に参加してください。以下の詳細を含めてください: + - あなたの名前 + - あなたのGitHubユーザー名 + - あなたの興味のある分野 + - NLPおよび/またはAIに関連するスキルと経験 +1. 公式GitHub[OpenRLHF ↗](https://github.com/OpenRLHF/OpenRLHF)プロジェクトページを通じて参加することもできます。貢献したい興味についてのissueを作成するだけで、私たちが連絡します。 + +**何ができるか** + +1. チームに参加し、OpenRLHFプロジェクトの開発に参加します。 +1. プロジェクトに貢献するためにプルリクエストを提出します。 +1. ドキュメントの改善、バグの修正、新機能の作成を手伝います。 +1. プロジェクトを共有し、コミュニティの成長を支援します。 + +## スポンサー + +スポンサーシップは、OpenRLHFの維持と改善に役立ちます。このプロジェクトが役立つと感じた場合は、スポンサーを検討してください。[Open Collective ↗](https://opencollective.com/OpenRLHF)でスポンサーになることができます。 + +## スター履歴 + +[![Star History Chart](https://api.star-history.com/svg?repos=OpenRLHF/OpenRLHF&type=Date)](https://star-history.com/#OpenRLHF/OpenRLHF&Date) + +## 貢献者 + +すべての貢献者に感謝します!貢献したい場合は、プルリクエストを作成するか、issueを作成してください。 + + + + + +## 参考文献と謝辞 + +AIおよびNLP分野への貢献に対して、以下のプロジェクトおよび組織に感謝します: + +- [Hugging Face Transformers ↗](https://github.com/huggingface/transformers) +- [OpenAI GPT ↗](https://github.com/openai/gpt-3) +- [LLaMA ↗](https://llama.meta.com/) +- [DeepSpeed ↗](https://github.com/microsoft/DeepSpeed) +- [Ray ↗](https://github.com/ray-project/ray) + +私たちのプロジェクトは、[ColossalChat](https://github.com/hpcaitech/ColossalAI/tree/main/applications/Chat)および[DeepSpeedChat](https://github.com/microsoft/DeepSpeedExamples/tree/master/applications/DeepSpeed-Chat)にも感謝します。プロジェクトの初期段階で、彼らのコードデザインを参考にしました。 + +(2024/7) 私たちのGitHub組織はOpenLLMAIからOpenRLHFに変更されました。 + +## 引用 +``` +@article{hu2024openrlhf, + title={OpenRLHF: An Easy-to-use, Scalable and High-performance RLHF Framework}, + author={Jian Hu and Xibin Wu and Zilin Zhu and Xianyu and Weixun Wang and Dehao Zhang and Yu Cao}, + journal={arXiv preprint arXiv:2405.11143}, + year={2024} +} +``` + +______________________________________________________________________ + +*OpenRLHF © 2025 OpenRLHF. All Rights Reserved.* diff --git a/PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/README_zh.md b/PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/README_zh.md new file mode 100644 index 0000000000..f7f0bedbf9 --- /dev/null +++ b/PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/README_zh.md @@ -0,0 +1,482 @@ +
+<div align="center">
+  <!-- badges: GitHub Contributors / Issues / GitHub pull requests / GitHub stars -->
+  <p align="center"><em>开源 / 全面 / 轻量级 / 易用</em></p>
+</div>
+
+ +[ English | 中文 | 日本語 ] + +OpenRLHF 是一个基于 Ray、DeepSpeed 和 HF Transformers 构建的高性能 RLHF 框架: + +- **简单易用**: OpenRLHF 是目前可用的最简单的高性能 RLHF 库之一,无缝兼容 Huggingface 模型和数据集。 +- **高性能**: RLHF 训练中 80% 的时间用于样本生成阶段。得益于使用 Ray, Packing Samples 以及 vLLM 生成加速的能力,OpenRLHF 的性能是极致优化的 DeepSpeedChat with Hybrid Engine 的3~4倍以上。 +- **分布式 RLHF**: OpenRLHF 使用 Ray 将 Actor、Reward、Reference 和 Critic 模型分布到不同的 GPU 上,同时将 Adam 优化器放在 CPU 上。这使得使用多个 A100 80G GPU 和 vLLM 可以全面微调超过 70B+ 的模型 以及在多个 24GB RTX 4090 GPU 上微调 7B 模型。 +- **PPO 实现技巧**: 我们集成了 PPO 的实现技巧以提高训练稳定性,详情参考 [知乎](https://zhuanlan.zhihu.com/p/622134699) 和 [Notion blog](https://hijkzzz.notion.site/rlhf-implementation-tricks?v=158d9a33ecc98132bf9e000c39227361). + +更多细节请参考 [PPT](https://docs.google.com/presentation/d/1JRhB1d7csofx0PIZBmfyBdMluxNd5JLPpUHrrvVhGnk/edit?usp=sharing) | [技术报告](https://arxiv.org/abs/2405.11143) | [使用文档](https://openrlhf.readthedocs.io/) + + +## 新闻 +- [2024/12] 我们"提出"了 😊 [REINFORCE++ 对齐算法](https://www.researchgate.net/publication/387487679_REINFORCE_A_SIMPLE_AND_EFFICIENT_APPROACH_FOR_ALIGNING_LARGE_LANGUAGE_MODELS). +- [2024/12] 在 [Notion Blog](https://hijkzzz.notion.site/unraveling-rlhf-and-its-variants-engineering-insights#147d9a33ecc9806090f3d5c749d31f05) 中,我们对 PPO、REINFORCE++、GRPO 和 RLOO 进行了分析。 + +## 特性 + +- 基于 Ray 的分布式 [PPO](./examples/scripts/train_ppo_llama_ray.sh) 和 [REINFORCE++/RLOO](./examples/scripts/train_reinforce_llama_ray.sh) 实现。 +- 支持对 [超过 700 亿参数的模型](./examples/scripts/train_ppo_llama_ray_70b.sh) 进行完整的 RLHF 微调。 +- 集成 vLLM,加速 RLHF 任务中的样本生成(`--vllm_num_engines`)。 +- 支持多个奖励模型(`--reward_pretrain model1,model2...`)和远程奖励模型(`--remote_rm_url`)。 +- 实现 [DPO(直接偏好优化)/IPO/cDPO](./examples/scripts/train_dpo_llama.sh) 和 [Kahneman-Tversky Optimization(KTO)](./examples/scripts/train_kto_llama.sh)。 +- 支持 [迭代 DPO](./examples/scripts/train_iterative_dpo_llama.sh)([GitHub: Online-RLHF](https://github.com/RLHFlow/Online-RLHF))。 +- 支持 [拒绝采样](./examples/scripts/train_rejection_sampling_llama.sh)。 +- 实现 [条件 SFT](./examples/scripts/train_conditional_llama.sh)([arXiv:2308.12050](https://arxiv.org/abs/2308.12050))。 +- 支持 [知识蒸馏](./examples/scripts/train_knowledge_distillation.sh)([Microsoft: minillm](https://github.com/microsoft/LMOps/tree/main/minillm))。 +- 集成 [过程奖励模型(PRM)](./examples/scripts/train_prm_mistral.sh)。 +- 支持 SFT、DPO、RM、PRM 和 PPO 的训练样本打包(`--packing_samples`)。 +- 实现 [RingAttention](./examples/scripts/train_dpo_ring_llama.sh)(`--ring_attn_size`,`--ring_head_stride`)。 +- 支持 [专家混合模型(MoE)](./examples/test_scripts/train_sft_mixtral_lora.sh)(`--aux_loss_coef`)。 +- 集成 FlashAttention2(`--flash_attn`)。 +- 支持 QLoRA(`--load_in_4bit`)和 [LoRA](./examples/scripts/train_sft_mixtral_lora.sh)(`--lora_rank`,`--target_modules`)。 +- 兼容 HuggingFace 的 `tokenizer.apply_chat_template` 数据集格式(`--apply_chat_template` 和 `--input_key`)。 +- 支持使用 Wandb(`--use_wandb`)和 TensorBoard(`--use_tensorboard`)进行日志记录。 +- 支持从检查点恢复训练(`--load_checkpoint` 和 `--save_steps`)。 +- 提供了多节点训练脚本, 比如 [DPO](./examples/scripts/train_llama_slurm.sh) 和 [RLHF](./examples/scripts/train_ppo_llama_ray_slurm.sh) + + +### PPO 支持矩阵 + +| 特性 | OpenRLHF | DSChat | CAIChat | TRL | +| ------------- |:-------------:| :-------------:| :-------------:| :-------------:| +| 使用 16 个 A100 完成 70B+ 全微调 | ✅ | ❌ | ❌ | ❌ || +| 使用 4 个 RTX4090 完成 7B 全微调 | ✅ | ❌ | ❌ | ❌ | +| 使用 8 个 A100 完成 34B DPO 全微调 | ✅ | ❌ | ❌ | ❌ | +| 支持推理引擎加速 | ✅ | ✅ | ❌ | ❌ | +| PPO 实现技巧 | ✅ | ❌ | ❌ | ✅ | +| 支持 QLoRA | ✅ | ❌ | ❌ | ✅ | +| 支持 Mixtral 8*7b | ✅ | ❌ | ❌ | ❌ | +| 支持未合并的 Actor-Critic | ✅ | ✅ | ✅ | ❌ | +| 支持多个奖励模型 | ✅ | ❌ | ❌ | ❌ | +| 支持 Huggingface 
模型 | ✅ | ✅ | ✅ | ✅ | +| 易于使用 | ✅ | ❌ (HybridEngine bugs) | ✅ | ✅ | + +## 快速开始 + +### 安装 + +要使用 OpenRLHF,首先启动 Docker 容器(**推荐**)然后执行 `pip install` 安装 `openrlhf`: + +```bash +# 启动 docker container +docker run --runtime=nvidia -it --rm --shm-size="10g" --cap-add=SYS_ADMIN -v $PWD:/openrlhf nvcr.io/nvidia/pytorch:24.07-py3 bash +sudo pip uninstall xgboost transformer_engine flash_attn -y + +# pip install +pip install openrlhf + +# 如果你需要使用 vLLM 加速 (安装 vLLM 0.6.5) +pip install openrlhf[vllm] +# 最新的 vLLM 也是支持的 +pip install openrlhf[vllm_latest] + +# pip install GitHub 上的最新版 +pip install git+https://github.com/OpenRLHF/OpenRLHF.git + +# 或者 git clone +git clone https://github.com/OpenRLHF/OpenRLHF.git +cd OpenRLHF +pip install -e . +``` + +> [!NOTE] +>我们推荐使用 vLLM 0.6.4+,其他版本 (vLLM >= 0.4.2) 可能需要通过 Gloo 进行权重同步(`--vllm_sync_backend gloo`)。 +>我们也提供了 [Dockerfiles for vLLM](./dockerfile/) 和 [Nvidia-Docker 一键安装脚本](./examples/scripts/nvidia_docker_install.sh)。 + +### 准备数据集 +OpenRLHF 在其数据集类中提供了多种数据处理方法。 +例如在 [Prompt Dataset](https://github.com/OpenRLHF/OpenRLHF/blob/main/openrlhf/datasets/prompts_dataset.py#L6) 中: + +```python +def preprocess_data(data, input_template=None, input_key="input", apply_chat_template=None) -> str: + if apply_chat_template: + chat = data[input_key] + if isinstance(chat, str): + chat = [{"role": "user", "content": chat}] + prompt = apply_chat_template(chat, tokenize=False, add_generation_prompt=True) + else: + prompt = data[input_key] + if input_template: + prompt = input_template.format(prompt) + return prompt +``` + +- 我们可以使用 `--input_key` 指定 `JSON key name` 为输入数据集 `--prompt_data {name or path}` (PPO) 或 `--dataset {name or path}`,并使用 `--apply_chat_template` 利用 [Huggingface Tokenizer](https://huggingface.co/docs/transformers/main/en/chat_templating) 中的 `chat_template`。 +- 如果不想使用 `--apply_chat_template`,可以改用 `--input_template`,或预先离线处理数据集。 +- OpenRLHF 还支持使用 `--prompt_data_probs 0.1,0.4,0.5` (PPO) 或 `--dataset_probs 0.1,0.4,0.5` 混合多个数据集。 + +Chat Templating 的工作原理如下: + +```python +dataset = [{"input_key": [ + {"role": "user", "content": "Hello, how are you?"}, + {"role": "assistant", "content": "I'm doing great. How can I help you today?"}, + {"role": "user", "content": "I'd like to show off how chat templating works!"}, +]}] + +tokenizer.apply_chat_template(dataset[0]["input_key"], tokenize=False) + +"[INST] Hello, how are you? [/INST]I'm doing great. How can I help you today? [INST] I'd like to show off how chat templating works! [/INST]" +``` + +如何指定训练和测试数据分区 ? 
+ +你可以使用 `data_type@data_dir` 的方式指定, 比如下面的数据集可以设置为 `--dataset json@./data` + +``` +data +├── test.jsonl +└── train.jsonl +``` + +> [!NOTE] +>默认情况下我们使用 `train` 和 `test` 作为 split 区分 Huggingface 的训练/测试数据。 +>`JSON key` 选项取决于具体的数据集。请参阅 [Reward Dataset](https://github.com/OpenRLHF/OpenRLHF/blob/main/openrlhf/datasets/reward_dataset.py#L10) 和 [SFT Dataset](https://github.com/OpenRLHF/OpenRLHF/blob/mai + + +### Supervised Fine-tuning + +OpenRLHF 的模型检查点完全兼容 HuggingFace 模型。您可以使用 `--pretrain {name or path}`、`--reward_pretrain {name or path}` 和 `--critic_pretrain {name or path}` 指定模型名称或路径。我们在 [HuggingFace OpenRLHF](https://huggingface.co/OpenRLHF) 上提供了一些预训练的检查点和数据集。 + +然后您可以使用我们在 [examples/scripts](./examples/scripts/) 目录中提供的启动脚本,或者使用以下命令启动训练: + +```bash +deepspeed --module openrlhf.cli.train_sft \ + --max_len 4096 \ + --dataset Open-Orca/OpenOrca \ + --input_key question \ + --output_key response \ + --input_template $'User: {}\nAssistant: ' \ + --train_batch_size 256 \ + --micro_train_batch_size 2 \ + --max_samples 500000 \ + --pretrain meta-llama/Meta-Llama-3-8B \ + --save_path ./checkpoint/llama3-8b-sft \ + --save_steps -1 \ + --logging_steps 1 \ + --eval_steps -1 \ + --zero_stage 2 \ + --max_epochs 1 \ + --bf16 \ + --flash_attn \ + --learning_rate 5e-6 \ + --gradient_checkpointing \ + --packing_samples \ + --load_checkpoint \ + --use_wandb {wandb_token} + +# 支持 HF tokenizer.apply_chat_template +# --apply_chat_template +# --tokenizer_chat_template {HF Chat Template} + +# 支持 RingAttention +# pip install ring_flash_attn +# --ring_attn_size 2 \ +# --ring_head_stride 2 \ + +# 也可用于 continued pre-training +# --pretrain_mode +``` + +> [!NOTE] +> OpenRLHF SFT/DPO/RewardModel/PPO 训练支持 `--packing_samples` [基于 `--flash_attn`](https://github.com/MeetKai/functionary/tree/main/functionary/train/packing) + +### Reward Model Training +```bash +deepspeed --module openrlhf.cli.train_rm \ + --save_path ./checkpoint/llama3-8b-rm \ + --save_steps -1 \ + --logging_steps 1 \ + --eval_steps -1 \ + --train_batch_size 256 \ + --micro_train_batch_size 1 \ + --pretrain OpenRLHF/Llama-3-8b-sft-mixture \ + --bf16 \ + --max_epochs 1 \ + --max_len 8192 \ + --zero_stage 3 \ + --learning_rate 9e-6 \ + --dataset OpenRLHF/preference_dataset_mixture2_and_safe_pku \ + --apply_chat_template \ + --chosen_key chosen \ + --rejected_key rejected \ + --flash_attn \ + --packing_samples \ + --gradient_checkpointing \ + --load_checkpoint \ + --use_wandb {wandb_token} + +``` + +推荐设置 Reward Model 的 `--value_prefix_head` 选项为 `score`, 这样使得我们可以用 `AutoModelForSequenceClassification` 加载模型: + +```python +reward_model = AutoModelForSequenceClassification.from_pretrained( + reward_model_path, + num_labels=1, + torch_dtype=torch.bfloat16, + attn_implementation="flash_attention_2", + use_cache=False, + ) +inputs = xxxx (Left Padding Input Tokens) +reward = reward_model.model(*inputs).last_hidden_state +reward = reward_model.score(reward)[:, -1] +``` + +### 不使用 Ray 的 PPO + +```bash +deepspeed --module openrlhf.cli.train_ppo \ + --pretrain OpenRLHF/Llama-3-8b-sft-mixture \ + --reward_pretrain OpenRLHF/Llama-3-8b-rm-mixture \ + --save_path ./checkpoint/llama-3-8b-rlhf \ + --save_steps -1 \ + --logging_steps 1 \ + --eval_steps -1 \ + --micro_train_batch_size 2 \ + --train_batch_size 128 \ + --micro_rollout_batch_size 4 \ + --rollout_batch_size 1024 \ + --max_epochs 1 \ + --prompt_max_len 1024 \ + --generate_max_len 1024 \ + --zero_stage 2 \ + --bf16 \ + --actor_learning_rate 5e-7 \ + --critic_learning_rate 9e-6 \ + --init_kl_coef 0.01 \ + --prompt_data 
OpenRLHF/prompt-collection-v0.1 \ + --input_key context_messages \ + --apply_chat_template \ + --max_samples 100000 \ + --normalize_reward \ + --adam_offload \ + --flash_attn \ + --gradient_checkpointing \ + --load_checkpoint \ + --use_wandb {wandb_token} + +# 支持远程 reward model (HTTP) +# --remote_rm_url http://localhost:5000/get_reward +``` + +### 使用 Ray 和 vLLM 的 PPO/REINFORCE++ + +为了提高 RLHF 训练速度或支持 70B 模型,我们可以使用 Ray 和 vLLM 加速的 PPO + +```bash +# 在容器中启动 Ray 的主节点 +ray start --head --node-ip-address 0.0.0.0 --num-gpus 8 + +# 如果要在更多节点上启动 Ray,请使用 +ray start --address {MASTER-NODE-ADDRESS}:6379 --num-gpus 8 + +ray job submit --address="http://127.0.0.1:8265" \ + --runtime-env-json='{"working_dir": "/openrlhf"}' \ + -- python3 -m openrlhf.cli.train_ppo_ray \ + --ref_num_nodes 1 \ + --ref_num_gpus_per_node 2 \ + --reward_num_nodes 1 \ + --reward_num_gpus_per_node 2 \ + --critic_num_nodes 1 \ + --critic_num_gpus_per_node 2 \ + --actor_num_nodes 1 \ + --actor_num_gpus_per_node 2 \ + --vllm_num_engines 2 \ + --vllm_tensor_parallel_size 2 \ + --colocate_critic_reward \ + --colocate_actor_ref \ + --pretrain OpenRLHF/Llama-3-8b-sft-mixture \ + --reward_pretrain OpenRLHF/Llama-3-8b-rm-mixture \ + --save_path /openrlhf/examples/checkpoint/llama3-8b-rlhf \ + --micro_train_batch_size 8 \ + --train_batch_size 128 \ + --micro_rollout_batch_size 32 \ + --rollout_batch_size 1024 \ + --max_samples 100000 \ + --max_epochs 1 \ + --prompt_max_len 1024 \ + --generate_max_len 1024 \ + --zero_stage 3 \ + --bf16 \ + --actor_learning_rate 5e-7 \ + --critic_learning_rate 9e-6 \ + --init_kl_coef 0.01 \ + --prompt_data OpenRLHF/prompt-collection-v0.1 \ + --input_key context_messages \ + --apply_chat_template \ + --normalize_reward \ + --packing_samples \ + --adam_offload \ + --flash_attn \ + --gradient_checkpointing \ + --load_checkpoint \ + --use_wandb {wandb_token} + +# 支持 REINFORCE++ | RLOO +# --advantage_estimator reinforce | rloo + +# 支持远程 reward model (HTTP) +# --remote_rm_url http://localhost:5000/get_reward + +# 支持 N 倍采样 +# --n_samples_per_prompt 4 +``` + +> [!NOTE] +> 不设置 `--vllm_num_engines` 则是不使用 vLLM engine。 +> 您也可以通过 ``setup_commands`` 让 Ray 自动初始化环境, 比如 `--runtime-env-json='{"setup_commands": ["pip install openrlhf[vllm]"]}'` + +> [!NOTE] +> OPENRLHF's RLOO 基于 REINFORCE++ 修改而来, 和原版的实现不同. 
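+
+下面给出一个极简的示意片段(仅为概念性的假设示例,并非本仓库 `--advantage_estimator` 的实际实现;函数名 `reinforce_advantage` / `rloo_advantage` 均为假设):REINFORCE++ 可以粗略理解为对整个 batch 的回报做全局归一化得到优势,而 RLOO 则对同一 prompt 的 `--n_samples_per_prompt` 个采样使用留一(leave-one-out)均值作为基线:
+
+```python
+import torch
+
+def reinforce_advantage(rewards: torch.Tensor) -> torch.Tensor:
+    # 全局白化: 减去 batch 均值并除以标准差, 作为无 critic 的优势估计
+    return (rewards - rewards.mean()) / (rewards.std() + 1e-8)
+
+def rloo_advantage(rewards: torch.Tensor, n_samples_per_prompt: int) -> torch.Tensor:
+    # 留一基线: 同一 prompt 的 k 个采样中, 每个样本减去其余 k-1 个样本的平均回报
+    k = n_samples_per_prompt
+    grouped = rewards.reshape(-1, k)                   # (num_prompts, k)
+    baseline = (grouped.sum(dim=-1, keepdim=True) - grouped) / (k - 1)
+    return (grouped - baseline).reshape(-1)
+
+# 用法示例: 4 个 prompt, 每个 prompt 采样 4 次得到的标量回报
+rewards = torch.randn(16)
+adv_reinforce = reinforce_advantage(rewards)
+adv_rloo = rloo_advantage(rewards, n_samples_per_prompt=4)
+```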
+ +> [!NOTE] +> 如果您由于某种原因,在 deepspeed 设置显卡设备时遇到与索引超出范围相关的错误,您可以尝试设置环境变量 [`RAY_EXPERIMENTAL_NOSET_*_VISIBLE_DEVICES`](openrlhf/trainer/ray/utils.py)。 +> ```bash +> # 对于 NVIDIA 显卡: +> export RAY_EXPERIMENTAL_NOSET_CUDA_VISIBLE_DEVICES=1 +> ``` + +所有支持算法的启动脚本和文档在 [example/scripts](./examples/scripts/) 和 [Documents - Usage](https://openrlhf.readthedocs.io/en/latest/usage.html) + + +### 使用 LoRA +如果您使用了 `LoRA (Low-Rank Adaptation)`,默认保存下来的文件**并非**完整模型权重,而是 `LoRA Adapter`,若想按完整权重的方式进行后续任务,您需要将 `Adapter` 与训练前的模型权重进行合并 + +```bash +python -m openrlhf.cli.lora_combiner \ + --model_path meta-llama/Meta-Llama-3-8B \ + --lora_path ./checkpoint/llama3-8b-rm \ + --output_path ./checkpoint/llama-3-8b-rm-combined \ + --is_rm \ + --bf16 +``` + +## 性能 +我们通过启用Adam卸载、奖励模型(RM)和参考模型(Ref)卸载等技术,尽可能优化了DSChat的性能,从而在推理阶段增加小批量大小并避免内存不足问题。我们甚至修复了DSChat中的一些bug,以启用LLaMA2的混合引擎(HE)。使用优化后的DSChat和OpenRLHF训练1024个提示需要1个PPO轮次的平均时间(秒)如下: + +| **Size** | **NVIDIA A800 GPUs** | **Optimized DSChat (with Hybrid Engine)** | **OpenRLHF** | **Speedup** | +| :---: | :---: | :---: | :---: | :---: | +| 7B | 16 | 855.09 | 471.11 | 1.82x | +| 13B | 32 | 1528.93 | 608.93 | 2.5x | +| 34B | 32 | 3634.98 | 1526.4 | 2.4x | +| 70B | 32 | 10407.0 | 4488.53 | 2.3x | + + +> [!NOTE] +> 数据已经过时; 请参考后面的调优指南重新测试 + +## 调优指南 +为了获得最佳的性能,我们建议您分配更多的节点给 vLLM Engine。例如,对于 70B 模型以及 32 张 A100,建议分配 16 张以上 A100 给 vLLM Engine,8 张给 Actor 模型,以及最后 8 张给 Critic 模型,同时开启 `--colocate_critic_reward`, `--colocate_actor_ref` 或者 `--ref_reward_offload (可选)` 选项合并部分节点。最后您应该尽可能增大 `--rollout_micro_batch_size` ,以及减小 vLLM 的 TP 切分数量。训练阶段的 `micro_train_batch_size` 也是越大越好,请同时使用 `--packing_samples` 。当 GPU 数量足够时请关闭 `--adam_offload` 以及启用 `--overlap_comm`. 对于多节点 RLHF, 请使用 `--vllm_sync_backend nccl` with vLLM 0.6.4+. + +## 使用 OpenRLHF 的公司和组织 + +- Google +- ByteDance +- Tencent +- Alibaba +- Baidu +- China Telecom +- Allen AI +- Vivo +- NexusFlow +- Jülich Supercomputing Centre (JSC) +- Berkeley Starling Team +- M-A-P +- ... + + +## 加入我们 + +**如何加入?** + +1. 通过联系邮箱 janhu9527@gmail.com 或者加入 [GitHub Organization](https://github.com/OpenRLHF)。请包含以下信息: + - 您的姓名 + - 您的 GitHub 用户名 + - 您感兴趣的领域 + - 您在 NLP 和/或 AI 相关的技能和经验 +2. 您也可以通过官方 GitHub [OpenRLHF ↗](https://github.com/OpenRLHF/OpenRLHF) 项目页面加入我们。只需创建一个关于您想要贡献的兴趣的 issue,我们会与您联系。 + +**您能做什么?** + +1. 加入团队,参与 OpenRLHF 项目的开发。 +2. 通过提交 pull 请求来为项目做出贡献。 +3. 帮助改进文档,修复 bugs 或创建新功能。 +4. 分享项目并帮助我们发展社区。 + +## 赞助我们 + +您的赞助可以帮助我们维护和改进 OpenRLHF。如果您觉得这个项目有用,请考虑赞助我们。您可以在 [Open Collective ↗](https://opencollective.com/OpenRLHF) 上赞助我们。 + +## 星图 + +[![Star History Chart](https://api.star-history.com/svg?repos=OpenRLHF/OpenRLHF&type=Date)](https://star-history.com/#OpenRLHF/OpenRLHF&Date) + +## 贡献者 + +非常感谢所有的贡献者!如果您想贡献,请随时创建 pull 请求或创建 issue。 + + + + + +## 引用与致谢 + +我们想要对以下项目和组织在 AI 和 NLP 领域的贡献表示感谢: + +- [Hugging Face Transformers ↗](https://github.com/huggingface/transformers) +- [OpenAI GPT ↗](https://github.com/openai/gpt-3) +- [LLaMA ↗](https://llama.meta.com/) +- [DeepSpeed ↗](https://github.com/microsoft/DeepSpeed) +- [Ray ↗](https://github.com/ray-project/ray) + +我们的项目还想要感谢 [ColossalChat](https://github.com/hpcaitech/ColossalAI/tree/main/applications/Chat) 和 [DeepSpeedChat](https://github.com/microsoft/DeepSpeedExamples/tree/master/applications/DeepSpeed-Chat)。在项目的早期阶段,我们参考了他们的代码设计。 + +(2024/7) 我们的 GitHub 组织从 OpenLLMAI 迁移到了 OpenRLHF. 
+ +## 引用 +``` +@article{hu2024openrlhf, + title={OpenRLHF: An Easy-to-use, Scalable and High-performance RLHF Framework}, + author={Jian Hu and Xibin Wu and Zilin Zhu and Xianyu and Weixun Wang and Dehao Zhang and Yu Cao}, + journal={arXiv preprint arXiv:2405.11143}, + year={2024} +} +``` + + +______________________________________________________________________ + +*OpenRLHF © 2025 OpenRLHF. 版权所有。* diff --git a/PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/dockerfile/Dockerfile b/PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/dockerfile/Dockerfile new file mode 100644 index 0000000000..bbc1091ba8 --- /dev/null +++ b/PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/dockerfile/Dockerfile @@ -0,0 +1,23 @@ +FROM nvcr.io/nvidia/pytorch:24.07-py3 + +WORKDIR /app + +RUN set -eux && \ + apt-get update && \ + apt-get install -y gosu && \ + rm -rf /var/lib/apt/lists/* && \ + gosu nobody true + +RUN apt-get update && apt-get -y install sudo +RUN sudo su - + +RUN DEBIAN_FRONTEND=noninteractive apt install -y tzdata + +RUN apt-get -y install build-essential git python3-dev python3-pip libopenexr-dev libxi-dev libglfw3-dev libglew-dev libomp-dev libxinerama-dev libxcursor-dev gdb +RUN pip uninstall xgboost transformer_engine flash_attn -y +RUN pip install vllm==0.6.5 + +COPY docker-entrypoint.sh . +RUN chmod a+x docker-entrypoint.sh + +ENTRYPOINT ["/app/docker-entrypoint.sh"] diff --git a/PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/dockerfile/docker-entrypoint.sh b/PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/dockerfile/docker-entrypoint.sh new file mode 100644 index 0000000000..f7b530f10c --- /dev/null +++ b/PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/dockerfile/docker-entrypoint.sh @@ -0,0 +1,25 @@ +#!/bin/bash + +USER=${LOCAL_USER:-"root"} + +if [[ "${USER}" != "root" ]]; then + USER_ID=${LOCAL_USER_ID:-9001} + echo ${USER} + echo ${USER_ID} + + chown ${USER_ID} /home/${USER} + useradd --shell /bin/bash -u ${USER_ID} -o -c "" -m ${USER} + usermod -a -G root ${USER} + adduser ${USER} sudo + + # user:password + echo "${USER}:123" | chpasswd + + export HOME=/home/${USER} + export PATH=/home/${USER}/.local/bin/:$PATH +else + export PATH=/root/.local/bin/:$PATH +fi + +cd $HOME +exec gosu ${USER} "$@" \ No newline at end of file diff --git a/PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/docs/ppo_examples.md b/PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/docs/ppo_examples.md new file mode 100644 index 0000000000..5b2ff6ea88 --- /dev/null +++ b/PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/docs/ppo_examples.md @@ -0,0 +1,69 @@ +### Training settings + +* Base model: LLaMA2 7B +* SFT dataset: Ocra 500k +* RM datasets: Anthropic HH, LMSys, OASST, WebGPT ~ 240K+ +* PPO prompts: alpaca, Anthropic HH, OASST ~ 100K + +### Question: Please introduce the game GTA5 + +**SFT** + +Human: Please introduce the game GTA5 + +Assistant: GTA5, or Grand Theft Auto V, is an action-adventure video game developed by Rockstar North and published by Rockstar Games. It was released in 2013 for the PlayStation 3 and Xbox 360, and later in 2014 for the PlayStation 4 and Xbox One. The game is set in the fictional city of Los Santos, based in Los Angeles, and follows the story of three criminals as they plan and execute heists. + +GTA5 is an open-world game, meaning that players can explore the game's vast and detailed world at their leisure, completing missions, engaging in side activities, and interacting with other characters. 
The game features an extensive array of vehicles, from cars to boats to planes, and players can use them to travel around the city and engage in various forms of criminal activity. + +GTA5 is known for its immersive gameplay, stunning graphics, and rich storytelling. It has received critical acclaim and has become one of the best-selling video games of all time, with over 150 million copies sold worldwide. + +**PPO** + +Human: Please introduce the game GTA5 + +Assistant: 1. **Game Name:** Grand Theft Auto V (GTA V) + +2. **Platforms:** GTA V is available on PlayStation 3 and Xbox 360, and later released on PlayStation 4, Xbox One, and Microsoft Windows. + +3. **Genre:** Action-adventure game + +4. **Gameplay:** GTA V is an open-world game where players can explore the fictional city of Los Santos and its surrounding areas, completing missions, engaging in side activities, and interacting with non-player characters. The game features a third-person perspective and allows players to freely roam the game world, engage in combat, and complete objectives. + +5. **Storyline:** The game follows the story of three protagonists, Michael De Santa, Franklin Clinton, and Trevor Philips, as they become embroiled in a series of heists and other criminal activities. The game's storyline is considered one of the strongest aspects of GTA + +6. **Graphics and Sound:** GTA V features stunning graphics and realistic character animations, making the game world feel alive and immersive. The soundtrack and voice acting are also highly praised, with a diverse and engaging musical score that complements the game's action-packed gameplay. + +7. **Reception:** GTA V received critical acclaim upon its release, with many praising its storyline, gameplay, graphics, and sound. It is considered one of the best video games of all time and has sold over 150 million copies worldwide, making it one of the best-selling video games of all time. + +In summary, GTA V is an action-adventure game set in a massive open world, with a complex and engaging storyline, stunning graphics, and highly praised soundtrack and voice acting. It is considered one of the greatest video games of all time and continues to be a popular choice among gamers worldwide. + + +### Question: Which GPU is better, AMD or NVIDIA? + +**SFT** + +Human: Which GPU is better, AMD or NVIDIA? + +Assistant: + +The performance and suitability of a GPU (Graphics Processing Unit) depend on various factors, including the specific task or application being used, the model and generation of the GPU, and the system it is being used in. Both AMD and NVIDIA are well-known and reputable manufacturers of GPUs, and they both have a wide range of models and generations. + +In general, the performance of a GPU can vary depending on the specific task or application being used. For example, AMD's GPUs are known for their performance in gaming and video editing, while NVIDIA's GPUs are known for their performance in machine learning and deep learning. + +It is important to note that the performance of a GPU can also vary depending on the generation of the GPU. Newer generations of GPUs tend to have better performance and features than older generations. + +Ultimately, the best GPU for you will depend on your specific needs and the tasks you plan to use it for. It is recommended to research and compare the specific models and generations of AMD and NVIDIA GPUs to determine which one is the best fit for your needs. + +**PPO** + +Human: Which GPU is better, AMD or NVIDIA? 
+ +Assistant: 1. **Performance:** In terms of raw performance, NVIDIA typically outperforms AMD in most cases, especially in gaming and video editing. + +2. **Price:** NVIDIA GPUs tend to be more expensive than AMD GPUs, although this can vary depending on the specific model and its features. + +3. **Architecture:** NVIDIA's GPU architecture is based on the CUDA (Compute Unified Device Architecture) platform, which allows for more efficient processing of parallel computing tasks. AMD's GPU architecture, on the other hand, is based on the Stream Compute architecture, which is designed for general-purpose computing. + +4. **Application Support:** NVIDIA's GPUs have better support for certain applications, particularly in the fields of deep learning and artificial intelligence. AMD's GPUs, however, have better support for certain games and video editing software. + +In summary, NVIDIA's GPUs tend to be more powerful and efficient for certain tasks, while AMD's GPUs are better suited for other tasks and have better support for certain applications. Ultimately, the choice between the two depends on the specific needs and preferences of the user. diff --git a/PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/examples/scripts/docker_run.sh b/PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/examples/scripts/docker_run.sh new file mode 100644 index 0000000000..cc90b9e846 --- /dev/null +++ b/PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/examples/scripts/docker_run.sh @@ -0,0 +1,8 @@ +set -x + +PROJECT_PATH=$(cd $(dirname $0)/../../; pwd) +IMAGE_NAME="nvcr.io/nvidia/pytorch:24.07-py3" + +docker run --runtime=nvidia -it --rm --shm-size="10g" --cap-add=SYS_ADMIN \ + -v $PROJECT_PATH:/openrlhf -v $HOME/.cache:/root/.cache -v $HOME/.bash_history2:/root/.bash_history \ + $IMAGE_NAME bash \ No newline at end of file diff --git a/PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/examples/scripts/nvidia_docker_install.sh b/PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/examples/scripts/nvidia_docker_install.sh new file mode 100644 index 0000000000..70e9a0f7f6 --- /dev/null +++ b/PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/examples/scripts/nvidia_docker_install.sh @@ -0,0 +1,28 @@ +set -x + +# remove old docker +sudo apt-get autoremove docker docker-ce docker-engine docker.io containerd runc +dpkg -l |grep ^rc|awk '{print $2}' |sudo xargs dpkg -P +sudo apt-get autoremove docker-ce-* +sudo rm -rf /etc/systemd/system/docker.service.d +sudo rm -rf /var/lib/docker + +# install docker +curl https://get.docker.com | sh \ + && sudo systemctl --now enable docker + +# install nvidia-docker +distribution=$(. 
/etc/os-release;echo $ID$VERSION_ID) \ + && curl -fsSL https://nvidia.github.io/libnvidia-container/gpgkey | sudo gpg --dearmor -o /usr/share/keyrings/nvidia-container-toolkit-keyring.gpg \ + && curl -s -L https://nvidia.github.io/libnvidia-container/$distribution/libnvidia-container.list | \ + sed 's#deb https://#deb [signed-by=/usr/share/keyrings/nvidia-container-toolkit-keyring.gpg] https://#g' | \ + sudo tee /etc/apt/sources.list.d/nvidia-container-toolkit.list + +sudo apt-get update +sudo apt-get install -y nvidia-container-toolkit +sudo nvidia-ctk runtime configure --runtime=docker + +sudo groupadd docker +sudo usermod -aG docker $USER +newgrp docker +docker ps \ No newline at end of file diff --git a/PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/examples/scripts/serve_remote_rm.sh b/PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/examples/scripts/serve_remote_rm.sh new file mode 100644 index 0000000000..1379cc8df5 --- /dev/null +++ b/PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/examples/scripts/serve_remote_rm.sh @@ -0,0 +1,10 @@ +set -x + +python -m openrlhf.cli.serve_rm \ + --reward_pretrain OpenRLHF/Llama-3-8b-rm-700k \ + --port 5000 \ + --bf16 \ + --flash_attn \ + --normalize_reward \ + --max_len 8192 \ + --batch_size 16 \ No newline at end of file diff --git a/PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/examples/scripts/train_conditional_llama.sh b/PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/examples/scripts/train_conditional_llama.sh new file mode 100644 index 0000000000..71a53637b0 --- /dev/null +++ b/PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/examples/scripts/train_conditional_llama.sh @@ -0,0 +1,50 @@ +set -x + +checkSuccess() { + if [[ $? != 0 ]]; then + echo "FAILED $1" + exit 1 + fi +} + +mkdir -p ./checkpoint/llama-2-8b-csft +RM_OUTPUT=./checkpoint/llama-2-8b-csft/rm.jsonl + +read -r -d '' get_rewards_commands < 0)); then + POLICY_MODEL_PATH=$MODEL_OUTPUT_PATH + fi + + read -r -d '' generate_commands <$ITER_LOG_PATH + fi +done \ No newline at end of file diff --git a/PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/examples/scripts/train_knowledge_distillation.sh b/PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/examples/scripts/train_knowledge_distillation.sh new file mode 100644 index 0000000000..687913c70c --- /dev/null +++ b/PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/examples/scripts/train_knowledge_distillation.sh @@ -0,0 +1,32 @@ +set -x + +read -r -d '' training_commands <> ${JOBLOG} + +# load training commands +source ./${training_script} slurm +echo training_commands &>> ${JOBLOG} +echo $training_commands &>> ${JOBLOG} + +# master addr and port +export MASTER_ADDR=$(scontrol show hostnames $SLURM_JOB_NODELIST | head -n 1) +export MASTER_PORT=9901 + +srun --container-image="$IMAGE_NAME" \ + --container-mounts="$PROJECT_PATH:/openrlhf,$HOME/.cache:/root/.cache" \ + bash -c " cd /openrlhf; pip install . ; torchrun \ +--nproc_per_node $GPUS_PER_NODE --nnodes $SLURM_NNODES --node_rank $SLURM_PROCID \ +--master_addr $MASTER_ADDR --master_port $MASTER_PORT -m ${training_commands}" &>> ${JOBLOG} + +echo "$(date '+%Y-%m-%d %H:%M:%S') Job ${SLURM_JOB_ID} stopped ..." 
&>> ${JOBLOG} \ No newline at end of file diff --git a/PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/examples/scripts/train_ppo_llama.sh b/PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/examples/scripts/train_ppo_llama.sh new file mode 100644 index 0000000000..d2c12185b3 --- /dev/null +++ b/PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/examples/scripts/train_ppo_llama.sh @@ -0,0 +1,40 @@ +set -x + +read -r -d '' training_commands < +MOUNT="$OPENRLHF_PATH:/openrlhf,$HOME/.cache:/root/.cache" +IMAGE_NAME="nvcr.io/nvidia/pytorch:24.07-py3" +RAY_VERSION=2.12.0 + +JOBLOG="$(realpath .)/train_ppo_llama_ray-$SLURM_JOB_ID.log" +echo "$(date '+%Y-%m-%d %H:%M:%S') Job ${SLURM_JOB_ID} started ..." &>> ${JOBLOG} + +# launch ray daemon +nodes=$(scontrol show hostnames "$SLURM_JOB_NODELIST") # Getting the node names +nodes_array=( $nodes ) +node_1=${nodes_array[0]} +ip=$node_1 + +port=6379 +ip_head=$ip:$port +export ip_head +echo "IP Head: $ip_head" &>> ${JOBLOG} + +echo "STARTING HEAD at $node_1" &>> ${JOBLOG} +srun --nodes=1 --ntasks=1 -w "$node_1" --container-image="$IMAGE_NAME" --container-mounts="$MOUNT" bash -c \ +&& pip install ray[default]==$RAY_VERSION \ +&& /root/.local/bin/ray start --head --node-ip-address=$ip --port=$port --block" &>> ${JOBLOG} & +sleep 10s + +worker_num=$((SLURM_JOB_NUM_NODES)) #number of nodes other than the head node +for ((i = 1; i < worker_num; i++)); do +node_i=${nodes_array[$i]} +echo "STARTING WORKER $i at $node_i" &>> ${JOBLOG} +srun --nodes=1 --ntasks=1 -w "$node_i" --container-image="$IMAGE_NAME" --container-mounts="$MOUNT" bash -c \ + && pip install ray[default]==$RAY_VERSION \ + && /root/.local/bin/ray start --address "$ip_head" --block" &>> ${JOBLOG} & +sleep 1s; +done + +sleep 30s + +# ===== submit ray job ===== +# Job start +srun --overlap --nodes=1 --ntasks=1 -w "$node_1" --container-image="$IMAGE_NAME" --container-mounts="$MOUNT" bash -c \ +"pip install ray[default]==$RAY_VERSION \ +&& /root/.local/bin/ray job submit --address=http://localhost:8265 \ + --runtime-env-json='{\"working_dir\": \"/openrlhf\", \"pip\": \"/openrlhf/requirements.txt\"}' \ + -- python3 -m openrlhf.cli.train_ppo_ray \ + --ref_num_nodes 1 \ + --ref_num_gpus_per_node 4 \ + --reward_num_nodes 1 \ + --reward_num_gpus_per_node 4 \ + --critic_num_nodes 1 \ + --critic_num_gpus_per_node 4 \ + --actor_num_nodes 1 \ + --actor_num_gpus_per_node 4 \ + --vllm_num_engines 4 \ + --vllm_tensor_parallel_size 2 \ + --colocate_critic_reward \ + --colocate_actor_ref \ + --pretrain OpenRLHF/Llama-3-8b-sft-mixture \ + --reward_pretrain OpenRLHF/Llama-3-8b-rm-mixture \ + --save_path /openrlhf/examples/checkpoint/llama3-8b-rlhf \ + --micro_train_batch_size 8 \ + --train_batch_size 128 \ + --micro_rollout_batch_size 16 \ + --rollout_batch_size 1024 \ + --max_samples 100000 \ + --max_epochs 1 \ + --prompt_max_len 1024 \ + --generate_max_len 1024 \ + --zero_stage 3 \ + --bf16 \ + --actor_learning_rate 5e-7 \ + --critic_learning_rate 9e-6 \ + --init_kl_coef 0.01 \ + --prompt_data OpenRLHF/prompt-collection-v0.1 \ + --input_key context_messages \ + --apply_chat_template \ + --normalize_reward \ + --adam_offload \ + --flash_attn \ + --packing_samples \ + --vllm_sync_backend nccl \ + --gradient_checkpointing \ + --use_wandb {wandb_token}" &>> ${JOBLOG} + +echo "$(date '+%Y-%m-%d %H:%M:%S') Job ${SLURM_JOB_ID} stopped ..." 
&>> ${JOBLOG} \ No newline at end of file diff --git a/PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/examples/scripts/train_ppo_llama_with_remote_rm.sh b/PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/examples/scripts/train_ppo_llama_with_remote_rm.sh new file mode 100644 index 0000000000..6cdaac30c9 --- /dev/null +++ b/PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/examples/scripts/train_ppo_llama_with_remote_rm.sh @@ -0,0 +1,50 @@ +set -x + +# python -m openrlhf.cli.serve_rm \ +# --reward_pretrain OpenRLHF/Llama-3-8b-rm-700k \ +# --port 5000 \ +# --bf16 \ +# --flash_attn \ +# --normalize_reward \ +# --max_len 8192 \ +# --batch_size 16 + +ray job submit --address="http://127.0.0.1:8265" \ + --runtime-env-json='{"working_dir": "/openrlhf"}' \ + -- python3 -m openrlhf.cli.train_ppo_ray \ + --ref_num_nodes 1 \ + --ref_num_gpus_per_node 2 \ + --critic_num_nodes 1 \ + --critic_num_gpus_per_node 2 \ + --actor_num_nodes 1 \ + --actor_num_gpus_per_node 2 \ + --vllm_num_engines 2 \ + --vllm_tensor_parallel_size 2 \ + --colocate_actor_ref \ + --pretrain OpenRLHF/Llama-3-8b-sft-mixture \ + --remote_rm_url http://localhost:5000/get_reward \ + --save_path /openrlhf/examples/checkpoint/llama3-8b-rlhf \ + --micro_train_batch_size 8 \ + --train_batch_size 128 \ + --micro_rollout_batch_size 16 \ + --rollout_batch_size 1024 \ + --max_samples 100000 \ + --max_epochs 1 \ + --prompt_max_len 1024 \ + --generate_max_len 1024 \ + --zero_stage 3 \ + --bf16 \ + --actor_learning_rate 5e-7 \ + --critic_learning_rate 9e-6 \ + --init_kl_coef 0.01 \ + --prompt_data OpenRLHF/prompt-collection-v0.1 \ + --input_key context_messages \ + --apply_chat_template \ + --normalize_reward \ + --packing_samples \ + --adam_offload \ + --flash_attn \ + --gradient_checkpointing \ + --use_wandb {wandb_token} + +# --vllm_sync_backend nccl (Only for multi-nodes with vLLM 0.6.4+ or vLLM 0.4.2) \ No newline at end of file diff --git a/PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/examples/scripts/train_prm_mistral.sh b/PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/examples/scripts/train_prm_mistral.sh new file mode 100644 index 0000000000..09192e9d71 --- /dev/null +++ b/PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/examples/scripts/train_prm_mistral.sh @@ -0,0 +1,34 @@ +set -x + +read -r -d '' training_commands < 0)); then + POLICY_MODEL_PATH=$MODEL_OUTPUT_PATH + fi + + read -r -d '' generate_commands <$ITER_LOG_PATH + fi +done \ No newline at end of file diff --git a/PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/examples/scripts/train_rm_llama.sh b/PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/examples/scripts/train_rm_llama.sh new file mode 100644 index 0000000000..7e114f9a0f --- /dev/null +++ b/PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/examples/scripts/train_rm_llama.sh @@ -0,0 +1,31 @@ +set -x + +read -r -d '' training_commands < 1: + assert args.packing_samples, "packing_samples must be enabled when using ring attention" + + train(args) diff --git a/PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/openrlhf/cli/train_kd.py b/PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/openrlhf/cli/train_kd.py new file mode 100644 index 0000000000..d7dca88479 --- /dev/null +++ b/PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/openrlhf/cli/train_kd.py @@ -0,0 +1,232 @@ +import argparse +import math +import os +from datetime import datetime + +from transformers.trainer import get_scheduler + +from openrlhf.datasets import SFTDataset +from openrlhf.models import Actor +from openrlhf.trainer import KDTrainer +from openrlhf.utils 
import blending_datasets, get_strategy, get_tokenizer + + +def train(args): + # configure strategy + strategy = get_strategy(args) + strategy.setup_distributed() + + # configure model + # load huggingface model + model = Actor( + args.pretrain, + use_flash_attention_2=args.flash_attn, + bf16=args.bf16, + load_in_4bit=args.load_in_4bit, + lora_rank=args.lora_rank, + lora_alpha=args.lora_alpha, + target_modules=args.target_modules, + lora_dropout=args.lora_dropout, + ds_config=strategy.get_ds_train_config(is_actor=True), + ) + + # load teacher model for inference + teacher_model = Actor( + args.teacher_model, + use_flash_attention_2=args.flash_attn, + bf16=args.bf16, + load_in_4bit=args.load_in_4bit, + ds_config=strategy.get_ds_eval_config(offload=args.teacher_offload), + ) + if args.teacher_offload: + teacher_model._offload = True + + # configure tokenizer + tokenizer = get_tokenizer(args.pretrain, model.model, "right", strategy, use_fast=not args.disable_fast_tokenizer) + + strategy.print(model) + + # configure optimizer + optim = strategy.create_optimizer(model, lr=args.learning_rate, betas=args.adam_betas, weight_decay=args.l2) + + # prepare for data and dataset + train_data, eval_data = blending_datasets( + args.dataset, + args.dataset_probs, + strategy, + args.seed, + max_count=args.max_samples, + train_split=args.train_split, + eval_split=args.eval_split, + ) + train_data = train_data.select(range(min(args.max_samples, len(train_data)))) + eval_data = eval_data.select(range(min(args.max_samples, len(eval_data)))) + train_dataset = SFTDataset( + train_data, + tokenizer, + args.max_len, + strategy, + pretrain_mode=args.pretrain_mode, + input_template=args.input_template, + ) + eval_dataset = SFTDataset( + eval_data, + tokenizer, + args.max_len, + strategy, + pretrain_mode=args.pretrain_mode, + input_template=args.input_template, + ) + + train_dataloader = strategy.setup_dataloader( + train_dataset, args.micro_train_batch_size, True, True, train_dataset.collate_fn + ) + eval_dataloader = strategy.setup_dataloader( + eval_dataset, args.micro_train_batch_size, True, False, eval_dataset.collate_fn + ) + + # scheduler + num_update_steps_per_epoch = len(train_dataset) // args.train_batch_size + max_steps = math.ceil(args.max_epochs * num_update_steps_per_epoch) + + scheduler = get_scheduler( + args.lr_scheduler, + optim, + num_warmup_steps=math.ceil(max_steps * args.lr_warmup_ratio), + num_training_steps=max_steps, + scheduler_specific_kwargs={"min_lr": args.learning_rate * 0.1}, + ) + + # gradient_checkpointing + if args.gradient_checkpointing: + model.gradient_checkpointing_enable( + gradient_checkpointing_kwargs={"use_reentrant": args.gradient_checkpointing_use_reentrant} + ) + + # prepare models + ((model, optim, scheduler), teacher_model) = strategy.prepare((model, optim, scheduler), teacher_model) + + # load checkpoint + consumed_samples = 0 + if args.load_checkpoint and os.path.exists(args.ckpt_path): + _, states = strategy.load_ckpt(model.model, args.ckpt_path) + consumed_samples = states["consumed_samples"] + strategy.print(f"Loaded the checkpoint: {args.ckpt_path}, consumed_samples: {consumed_samples}") + + os.makedirs(args.save_path, exist_ok=True) + + # configure Trainer + trainer = KDTrainer( + model=model, + teacher_model=teacher_model, + strategy=strategy, + optim=optim, + train_dataloader=train_dataloader, + eval_dataloader=eval_dataloader, + scheduler=scheduler, + max_norm=args.max_norm, + pretrain_mode=args.pretrain_mode, + batch_size=args.train_batch_size, + 
max_epochs=args.max_epochs, + tokenizer=tokenizer, + ) + + trainer.fit(args, consumed_samples, num_update_steps_per_epoch) + + # save model checkpoint after fitting on only rank0 + strategy.save_model(model, tokenizer, args.save_path) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + # Checkpoints + parser.add_argument("--save_path", type=str, default="./ckpt") + parser.add_argument("--save_steps", type=int, default=-1) + parser.add_argument("--logging_steps", type=int, default=1) + parser.add_argument("--eval_steps", type=int, default=-1) + parser.add_argument("--ckpt_path", type=str, default="./ckpt/checkpoints_kd") + parser.add_argument("--max_ckpt_num", type=int, default=3) + parser.add_argument("--max_ckpt_mem", type=int, default=1e8) + parser.add_argument("--load_checkpoint", action="store_true", default=False) + + # DeepSpeed + parser.add_argument("--micro_train_batch_size", type=int, default=8, help="batch size per GPU") + parser.add_argument("--train_batch_size", type=int, default=128, help="Global training batch size") + parser.add_argument("--max_norm", type=float, default=1.0, help="Gradient clipping") + parser.add_argument("--gradient_checkpointing", action="store_true", default=False) + parser.add_argument("--seed", type=int, default=42) + parser.add_argument("--local_rank", type=int, default=-1, help="local_rank for deepspeed") + parser.add_argument("--zero_stage", type=int, default=2, help="DeepSpeed ZeRO stage") + parser.add_argument("--bf16", action="store_true", default=False, help="Enable bfloat16") + parser.add_argument("--zpg", type=int, default=1, help="ZeRO++ max partition size") + parser.add_argument("--adam_offload", action="store_true", default=False, help="Offload Adam Optimizer") + parser.add_argument("--flash_attn", action="store_true", default=False, help="Enable FlashAttention2") + parser.add_argument("--aux_loss_coef", type=float, default=0, help="MoE balancing loss") + parser.add_argument("--grad_accum_dtype", type=str, default=None, help="Adam grad accum data type") + parser.add_argument("--overlap_comm", action="store_true", default=False) + parser.add_argument("--gradient_checkpointing_use_reentrant", action="store_true", default=False) + parser.add_argument("--disable_fast_tokenizer", action="store_true", default=False) + + # LoRA + parser.add_argument("--load_in_4bit", action="store_true", default=False) + parser.add_argument("--lora_rank", type=int, default=0) + parser.add_argument("--lora_alpha", type=int, default=16) + parser.add_argument("--target_modules", type=str, nargs="*", default="all-linear") + parser.add_argument("--lora_dropout", type=float, default=0) + + # KD + parser.add_argument("--pretrain", type=str, default=None) + parser.add_argument("--teacher_model", type=str, default=None) + parser.add_argument("--max_epochs", type=int, default=1) + parser.add_argument("--kd_coef", type=float, default=0.4) + parser.add_argument("--learning_rate", type=float, default=5e-6) + parser.add_argument("--lr_warmup_ratio", type=float, default=0.03) + parser.add_argument("--pretrain_mode", action="store_true", default=False, help="Use pretrain loss") + parser.add_argument("--lr_scheduler", type=str, default="cosine_with_min_lr") + parser.add_argument("--l2", type=float, default=0, help="weight decay loss") + parser.add_argument("--adam_betas", type=float, nargs=2, default=(0.9, 0.95), help="Betas for Adam optimizer") + parser.add_argument("--teacher_offload", action="store_true", default=False) + + # Custom dataset + 
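+    # --dataset accepts a HuggingFace dataset name/path or local data in the
+    # data_type@data_dir form (e.g. json@./data with train.jsonl/test.jsonl);
+    # comma-separated datasets are blended according to --dataset_probs (see the README).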
parser.add_argument("--dataset", type=str, default=None) + parser.add_argument("--dataset_probs", type=str, default="1.0", help="sampling probs for datasets") + parser.add_argument("--train_split", type=str, default="train", help="train split of the HF dataset") + parser.add_argument("--eval_split", type=str, default="test", help="test split of the dataset") + + parser.add_argument("--input_key", type=str, default="input", help="JSON dataset key") + parser.add_argument("--output_key", type=str, default="output", help="JSON dataset key") + parser.add_argument("--input_template", type=str, default="User: {}\nAssistant: ") + parser.add_argument( + "--apply_chat_template", action="store_true", default=False, help="Use HF tokenizer chat template" + ) + + parser.add_argument("--max_samples", type=int, default=1e8, help="Max number of samples") + parser.add_argument("--max_len", type=int, default=2048, help="Max tokens for the samples") + + # wandb parameters + parser.add_argument("--use_wandb", type=str, default=None) + parser.add_argument("--wandb_org", type=str, default=None) + parser.add_argument("--wandb_group", type=str, default=None) + parser.add_argument("--wandb_project", type=str, default="openrlhf_train_sft") + parser.add_argument( + "--wandb_run_name", + type=str, + default="sft_%s" % datetime.now().strftime("%m%dT%H:%M"), + ) + + # TensorBoard parameters + parser.add_argument("--use_tensorboard", type=str, default=None, help="TensorBoard logging path") + + args = parser.parse_args() + + if args.input_template and "{}" not in args.input_template: + print("[Warning] {} not in args.input_template, set to None") + args.input_template = None + + if args.input_template and "\\n" in args.input_template: + print( + "[Warning] input_template contains \\n chracters instead of newline. " + "You likely want to pass $'\\n' in Bash or \"`n\" in PowerShell." 
+ ) + + train(args) diff --git a/PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/openrlhf/cli/train_kto.py b/PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/openrlhf/cli/train_kto.py new file mode 100644 index 0000000000..9bdbff2c33 --- /dev/null +++ b/PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/openrlhf/cli/train_kto.py @@ -0,0 +1,226 @@ +import argparse +import math +import os +from datetime import datetime + +from transformers.trainer import get_scheduler + +from openrlhf.datasets import UnpairedPreferenceDataset +from openrlhf.models import Actor +from openrlhf.trainer import KTOTrainer +from openrlhf.utils import blending_datasets, get_strategy, get_tokenizer + + +def train(args): + # configure strategy + strategy = get_strategy(args) + strategy.setup_distributed() + + # configure model + # load huggingface model + model = Actor( + args.pretrain, + use_flash_attention_2=args.flash_attn, + bf16=args.bf16, + load_in_4bit=args.load_in_4bit, + lora_rank=args.lora_rank, + lora_alpha=args.lora_alpha, + target_modules=args.target_modules, + lora_dropout=args.lora_dropout, + ds_config=strategy.get_ds_train_config(is_actor=True), + ) + + # configure tokenizer + tokenizer = get_tokenizer(args.pretrain, model.model, "right", strategy, use_fast=not args.disable_fast_tokenizer) + strategy.print(model) + + # load weights for ref model + ref_model = Actor( + args.pretrain, + use_flash_attention_2=args.flash_attn, + bf16=args.bf16, + load_in_4bit=args.load_in_4bit, + ds_config=strategy.get_ds_eval_config(offload=args.ref_offload), + ) + if args.ref_offload: + ref_model._offload = True + get_tokenizer(args.pretrain, ref_model.model, "right", strategy, use_fast=not args.disable_fast_tokenizer) + + # gradient_checkpointing + if args.gradient_checkpointing: + model.gradient_checkpointing_enable( + gradient_checkpointing_kwargs={"use_reentrant": args.gradient_checkpointing_use_reentrant} + ) + + # configure optimizer + optim = strategy.create_optimizer(model, lr=args.learning_rate, betas=args.adam_betas, weight_decay=args.l2) + + # prepare for data and dataset + train_data, eval_data = blending_datasets( + args.dataset, + args.dataset_probs, + strategy, + args.seed, + max_count=args.max_samples, + stopping_strategy="all_exhausted", + train_split=args.train_split, + eval_split=args.eval_split, + ) + train_data = train_data.select(range(min(args.max_samples, len(train_data)))) + eval_data = eval_data.select(range(min(args.max_samples, len(eval_data)))) + + train_dataset = UnpairedPreferenceDataset( + train_data, tokenizer, args.max_len, strategy, input_template=args.input_template + ) + eval_dataset = UnpairedPreferenceDataset( + eval_data, tokenizer, args.max_len, strategy, input_template=args.input_template + ) + train_dataloader = strategy.setup_dataloader( + train_dataset, + args.micro_train_batch_size, + True, + True, + train_dataset.collate_fn, + ) + eval_dataloader = strategy.setup_dataloader( + eval_dataset, args.micro_train_batch_size, True, False, eval_dataset.collate_fn + ) + + # scheduler + num_update_steps_per_epoch = len(train_dataset) // args.train_batch_size + max_steps = math.ceil(args.max_epochs * num_update_steps_per_epoch) + + scheduler = get_scheduler( + "cosine_with_min_lr", + optim, + num_warmup_steps=math.ceil(max_steps * args.lr_warmup_ratio), + num_training_steps=max_steps, + scheduler_specific_kwargs={"min_lr": args.learning_rate * 0.1}, + ) + + # strategy prepare + ((model, optim, scheduler), ref_model) = strategy.prepare((model, optim, scheduler), ref_model) + + # load 
checkpoint + consumed_samples = 0 + if args.load_checkpoint and os.path.exists(args.ckpt_path): + _, states = strategy.load_ckpt(model.model, args.ckpt_path) + consumed_samples = states["consumed_samples"] + strategy.print(f"Loaded the checkpoint: {args.ckpt_path}, consumed_samples: {consumed_samples}") + + os.makedirs(args.save_path, exist_ok=True) + + trainer = KTOTrainer( + model=model, + ref_model=ref_model, + tokenizer=tokenizer, + strategy=strategy, + optim=optim, + train_dataloader=train_dataloader, + eval_dataloader=eval_dataloader, + scheduler=scheduler, + max_norm=args.max_norm, + beta=args.beta, + max_epochs=args.max_epochs, + ) + trainer.fit(args, consumed_samples, num_update_steps_per_epoch) + + # save model checkpoint after fitting on only rank0 + strategy.save_model(model, tokenizer, args.save_path) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + # Checkpoints + parser.add_argument("--save_path", type=str, default="./ckpt") + parser.add_argument("--save_steps", type=int, default=-1) + parser.add_argument("--logging_steps", type=int, default=1) + parser.add_argument("--eval_steps", type=int, default=-1) + parser.add_argument("--ckpt_path", type=str, default="./ckpt/checkpoints_kto") + parser.add_argument("--max_ckpt_num", type=int, default=3) + parser.add_argument("--max_ckpt_mem", type=int, default=1e8) + parser.add_argument("--load_checkpoint", action="store_true", default=False) + + # DeepSpeed + parser.add_argument("--max_norm", type=float, default=1.0, help="Gradient clipping") + parser.add_argument("--micro_train_batch_size", type=int, default=8, help="batch size per GPU") + parser.add_argument("--train_batch_size", type=int, default=128, help="Global training batch size") + parser.add_argument("--gradient_checkpointing", action="store_true", default=False) + parser.add_argument("--seed", type=int, default=42) + parser.add_argument("--disable_fast_tokenizer", action="store_true", default=False) + parser.add_argument("--local_rank", type=int, default=-1, help="local_rank for deepspeed") + parser.add_argument("--zero_stage", type=int, default=2, help="DeepSpeed ZeRO stage") + parser.add_argument("--bf16", action="store_true", default=False, help="Enable bfloat16") + parser.add_argument("--ref_offload", action="store_true", default=False) + parser.add_argument("--zpg", type=int, default=1, help="ZeRO++ max partition size") + parser.add_argument("--adam_offload", action="store_true", default=False, help="Offload Adam Optimizer") + parser.add_argument("--flash_attn", action="store_true", default=False, help="Enable FlashAttention2") + parser.add_argument("--grad_accum_dtype", type=str, default=None, help="Adam grad accum data type") + parser.add_argument("--overlap_comm", action="store_true", default=False) + parser.add_argument("--gradient_checkpointing_use_reentrant", action="store_true", default=False) + + # LoRA + parser.add_argument("--load_in_4bit", action="store_true", default=False) + parser.add_argument("--lora_rank", type=int, default=0) + parser.add_argument("--lora_alpha", type=int, default=16) + parser.add_argument("--target_modules", type=str, nargs="*", default="all-linear") + parser.add_argument("--lora_dropout", type=float, default=0) + + # KTO + parser.add_argument("--max_epochs", type=int, default=1) + parser.add_argument("--beta", type=float, default=0.01) + parser.add_argument("--pretrain", type=str, default="bigscience/bloomz-1b7") + parser.add_argument("--learning_rate", type=float, default=1e-5) + 
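The cosine scheduler built earlier in this script is sized from the blended dataset and the batch/epoch flags here; a minimal sketch of that arithmetic, assuming made-up sizes:

import math

num_train_samples = 52000        # hypothetical size of the blended KTO dataset
train_batch_size = 128           # --train_batch_size
max_epochs = 1                   # --max_epochs
lr_warmup_ratio = 0.03           # --lr_warmup_ratio

num_update_steps_per_epoch = num_train_samples // train_batch_size   # 406
max_steps = math.ceil(max_epochs * num_update_steps_per_epoch)       # 406
num_warmup_steps = math.ceil(max_steps * lr_warmup_ratio)            # 13
print(max_steps, num_warmup_steps)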
parser.add_argument("--lr_warmup_ratio", type=float, default=0.03) + parser.add_argument("--l2", type=float, default=0.0, help="weight decay loss") + parser.add_argument("--aux_loss_coef", type=float, default=0, help="MoE balancing loss") + parser.add_argument("--adam_betas", type=float, nargs=2, default=(0.9, 0.95), help="Betas for Adam optimizer") + parser.add_argument("--desirable_loss_weight", type=float, default=1.0, help="Loss weight for desirable samples") + parser.add_argument( + "--undesirable_loss_weight", type=float, default=1.0, help="Loss weight for undesirable samples" + ) + + # Custom dataset + parser.add_argument("--dataset", type=str, default="Dahoas/full-hh-rlhf") + parser.add_argument("--dataset_probs", type=str, default="1.0", help="sampling probs for datasets") + parser.add_argument("--train_split", type=str, default="train", help="train split of the HF dataset") + parser.add_argument("--eval_split", type=str, default="test", help="test split of the dataset") + + parser.add_argument("--input_key", type=str, default="input", help="JSON dataset key") + parser.add_argument("--output_key", type=str, default=None, help="JSON dataset key") + parser.add_argument("--label_key", type=str, default="label") + + parser.add_argument("--input_template", type=str, default="User: {}\nAssistant: ") + parser.add_argument( + "--apply_chat_template", action="store_true", default=False, help="Use HF tokenizer chat template" + ) + parser.add_argument("--max_len", type=int, default=2048, help="Max tokens for the samples") + parser.add_argument("--max_samples", type=int, default=1e8, help="Max number of samples") + + # wandb parameters + parser.add_argument("--use_wandb", type=str, default=None) + parser.add_argument("--wandb_org", type=str, default=None) + parser.add_argument("--wandb_group", type=str, default=None) + parser.add_argument("--wandb_project", type=str, default="openrlhf_train_dpo") + parser.add_argument( + "--wandb_run_name", + type=str, + default="rm_%s" % datetime.now().strftime("%m%dT%H:%M"), + ) + + # TensorBoard parameters + parser.add_argument("--use_tensorboard", type=str, default=None, help="TensorBoard logging path") + + args = parser.parse_args() + + if args.input_template and "{}" not in args.input_template: + print("[Warning] {} not in args.input_template, set to None") + args.input_template = None + + if args.input_template and "\\n" in args.input_template: + print( + "[Warning] input_template contains \\n chracters instead of newline. " + "You likely want to pass $'\\n' in Bash or \"`n\" in PowerShell." 
+ ) + + train(args) diff --git a/PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/openrlhf/cli/train_ppo.py b/PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/openrlhf/cli/train_ppo.py new file mode 100644 index 0000000000..2fc03bac05 --- /dev/null +++ b/PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/openrlhf/cli/train_ppo.py @@ -0,0 +1,443 @@ +import argparse +import itertools +import math +import os +from datetime import datetime + +import torch +from transformers.trainer import get_scheduler + +from openrlhf.datasets import PromptDataset, SFTDataset +from openrlhf.models import Actor, get_llm_for_sequence_regression +from openrlhf.trainer import PPOTrainer +from openrlhf.utils import blending_datasets, get_strategy, get_tokenizer + + +def train(args): + # configure strategy + strategy = get_strategy(args) + strategy.setup_distributed() + + # configure model + # load huggingface model + actor = Actor( + args.pretrain, + use_flash_attention_2=args.flash_attn, + bf16=args.bf16, + load_in_4bit=args.load_in_4bit, + lora_rank=args.lora_rank, + lora_alpha=args.lora_alpha, + target_modules=args.target_modules, + lora_dropout=args.lora_dropout, + ds_config=strategy.get_ds_train_config(is_actor=True), + ) + + if args.actor_init_on_gpu: + actor = actor.to(torch.cuda.current_device()) + + if args.critic_pretrain: + critic = get_llm_for_sequence_regression( + args.critic_pretrain, + "critic", + normalize_reward=args.normalize_reward, + use_flash_attention_2=args.flash_attn, + bf16=args.bf16, + load_in_4bit=args.load_in_4bit, + lora_rank=args.lora_rank, + lora_alpha=args.lora_alpha, + target_modules=args.target_modules, + lora_dropout=args.lora_dropout, + ds_config=strategy.get_ds_train_config(is_actor=False), + value_head_prefix=args.value_head_prefix, + init_value_head=strategy.args.pretrain == strategy.args.critic_pretrain, + ) + else: + critic = None + + if not args.remote_rm_url: + reward_model = get_llm_for_sequence_regression( + args.reward_pretrain, + "reward", + normalize_reward=args.normalize_reward, + use_flash_attention_2=args.flash_attn, + bf16=args.bf16, + load_in_4bit=args.load_in_4bit, + ds_config=strategy.get_ds_train_config(is_actor=False), + value_head_prefix=args.value_head_prefix, + ) + else: + reward_model = None + + strategy.print("reward normalization status: {}".format(args.normalize_reward)) + if reward_model: + strategy.print("mean: {}, std {}".format(reward_model.mean, reward_model.std)) + + strategy.print(actor) + strategy.print(critic) + + # configure tokenizer + tokenizer = get_tokenizer(args.pretrain, actor.model, "left", strategy, use_fast=not args.disable_fast_tokenizer) + + # load weights for reference actor + initial_model = Actor( + args.pretrain, + use_flash_attention_2=args.flash_attn, + bf16=args.bf16, + load_in_4bit=args.load_in_4bit, + ds_config=strategy.get_ds_eval_config(offload=False), + ) + + if args.enable_ema: + ema_model = Actor( + args.pretrain, + use_flash_attention_2=args.flash_attn, + bf16=args.bf16, + load_in_4bit=args.load_in_4bit, + ds_config=strategy.get_ds_eval_config(offload=True), + ) + else: + ema_model = None + + # gradient_checkpointing + if args.gradient_checkpointing: + actor.gradient_checkpointing_enable( + gradient_checkpointing_kwargs={"use_reentrant": args.gradient_checkpointing_use_reentrant} + ) + if critic is not None: + critic.gradient_checkpointing_enable( + gradient_checkpointing_kwargs={"use_reentrant": args.gradient_checkpointing_use_reentrant} + ) + + # configure optimizer + actor_optim = strategy.create_optimizer( + actor, 
lr=args.actor_learning_rate, betas=args.adam_betas, weight_decay=args.l2 + ) + if args.critic_pretrain: + critic_optim = strategy.create_optimizer( + critic, lr=args.critic_learning_rate, betas=args.adam_betas, weight_decay=args.l2 + ) + else: + critic_optim = None + + # prepare datasets + prompts_data = blending_datasets( + args.prompt_data, + args.prompt_data_probs, + strategy, + args.seed, + max_count=args.max_samples, + return_eval=False, + train_split=args.prompt_split, + ) + prompts_data = prompts_data.select(range(min(args.max_samples, len(prompts_data)))) + prompts_dataset = PromptDataset(prompts_data, tokenizer, strategy, input_template=args.input_template) + + if args.pretrain_data: + pretrain_data = blending_datasets( + args.pretrain_data, + args.pretrain_data_probs, + strategy, + args.seed, + return_eval=False, + train_split=args.pretrain_split, + ) + pretrain_max_len = args.max_len if args.max_len else args.prompt_max_len + args.generate_max_len + pretrain_dataset = SFTDataset( + pretrain_data.select( + range(min(len(pretrain_data), args.max_epochs * len(prompts_dataset) * args.n_samples_per_prompt)) + ), + tokenizer, + pretrain_max_len, + strategy, + pretrain_mode=True, + ) + + # prepare dataloader + prompts_dataloader = strategy.setup_dataloader( + prompts_dataset, args.rollout_batch_size // strategy.world_size, True, True + ) + if args.pretrain_data: + pretrain_dataloader = itertools.cycle( + iter( + strategy.setup_dataloader( + pretrain_dataset, + args.micro_train_batch_size, + True, + True, + pretrain_dataset.collate_fn, + ) + ) + ) + else: + pretrain_dataloader = None + + # configure scheduler + num_update_steps_per_episodes = ( + len(prompts_dataset) * args.n_samples_per_prompt // args.train_batch_size * args.max_epochs + ) + max_steps = math.ceil(args.num_episodes * num_update_steps_per_episodes) + + actor_scheduler = get_scheduler( + "cosine_with_min_lr", + actor_optim, + num_warmup_steps=math.ceil(max_steps * args.lr_warmup_ratio), + num_training_steps=max_steps, + scheduler_specific_kwargs={"min_lr": args.actor_learning_rate * 0.1}, + ) + + if args.critic_pretrain: + critic_scheduler = get_scheduler( + "cosine_with_min_lr", + critic_optim, + num_warmup_steps=math.ceil(max_steps * args.lr_warmup_ratio), + num_training_steps=max_steps, + scheduler_specific_kwargs={"min_lr": args.critic_learning_rate * 0.1}, + ) + else: + critic_scheduler = None + + # prepare models/optimizers... 
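The episode arithmetic just above determines how many scheduler steps the actor and critic see; roughly, under hypothetical settings:

import math

num_prompts = 8192             # hypothetical len(prompts_dataset)
n_samples_per_prompt = 1       # --n_samples_per_prompt
train_batch_size = 128         # --train_batch_size
max_epochs = 1                 # --max_epochs (PPO epochs per rollout)
num_episodes = 1               # --num_episodes

num_update_steps_per_episodes = (
    num_prompts * n_samples_per_prompt // train_batch_size * max_epochs
)                                                                     # 64
max_steps = math.ceil(num_episodes * num_update_steps_per_episodes)   # 64
print(max_steps)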
+ ( + (actor, actor_optim, actor_scheduler), + (critic, critic_optim, critic_scheduler), + reward_model, + initial_model, + ) = strategy.prepare( + (actor, actor_optim, actor_scheduler), + (critic, critic_optim, critic_scheduler), + reward_model, + initial_model, + is_rlhf=True, + ) + + if ema_model: + ema_model._offload = True + ema_model = strategy.prepare(ema_model, is_rlhf=True) + + # load checkpoint + consumed_samples = 0 + if args.load_checkpoint and os.path.exists(os.path.join(args.ckpt_path, "_actor")): + _, states = strategy.load_ckpt(actor.model, os.path.join(args.ckpt_path, "_actor")) + if args.critic_pretrain: + strategy.load_ckpt(critic, os.path.join(args.ckpt_path, "_critic")) + consumed_samples = states["consumed_samples"] + strategy.print(f"Loaded the checkpoint: {args.ckpt_path}, consumed_samples: {consumed_samples}") + + os.makedirs(args.save_path, exist_ok=True) + + # configure Trainer + trainer = PPOTrainer( + strategy, + actor, + critic, + reward_model, + initial_model, + ema_model, + actor_optim, + critic_optim, + actor_scheduler, + critic_scheduler, + max_epochs=args.max_epochs, + micro_train_batch_size=args.micro_train_batch_size, + micro_rollout_batch_size=args.micro_rollout_batch_size, + gradient_checkpointing=args.gradient_checkpointing, + tokenizer=tokenizer, + prompt_max_len=args.prompt_max_len, + value_clip=args.value_clip, + eps_clip=args.eps_clip, + gamma=args.gamma, + lambd=args.lambd, + init_kl_coef=args.init_kl_coef, + kl_target=args.kl_target, + ema_beta=0.992, + ptx_coef=args.ptx_coef, + max_norm=args.max_norm, + # fro GPT generation + do_sample=True, + max_new_tokens=args.generate_max_len, + max_length=args.max_len, + temperature=args.temperature, + top_p=args.top_p, + pad_token_id=tokenizer.pad_token_id, + eos_token_id=tokenizer.eos_token_id, + # remote reward model + remote_rm_url=args.remote_rm_url, + save_hf_ckpt=args.save_hf_ckpt, + disable_ds_ckpt=args.disable_ds_ckpt, + ) + + trainer.fit(args, prompts_dataloader, pretrain_dataloader, consumed_samples, num_update_steps_per_episodes) + + # save model checkpoint after fitting on only rank0 + strategy.save_model( + ema_model if args.enable_ema else actor, + tokenizer, + args.save_path, + ) + + if args.critic_pretrain and args.save_value_network: + strategy.save_model( + critic, + tokenizer, + args.save_path + "_critic", + ) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + # Checkpoint + parser.add_argument("--save_path", type=str, default="./ckpt") + parser.add_argument("--save_steps", type=int, default=-1) + parser.add_argument("--save_hf_ckpt", action="store_true", default=False) + parser.add_argument("--disable_ds_ckpt", action="store_true", default=False) + parser.add_argument("--logging_steps", type=int, default=1) + parser.add_argument("--eval_steps", type=int, default=-1) + parser.add_argument("--ckpt_path", type=str, default="./ckpt/checkpoints_ppo") + parser.add_argument("--max_ckpt_num", type=int, default=3) + parser.add_argument("--max_ckpt_mem", type=int, default=1e8) + parser.add_argument("--load_checkpoint", action="store_true", default=False) + + # PPO + parser.add_argument("--num_episodes", type=int, default=1) + parser.add_argument("--rollout_batch_size", type=int, default=512) + parser.add_argument("--micro_rollout_batch_size", type=int, default=8) + parser.add_argument("--max_epochs", type=int, default=1) + parser.add_argument("--prompt_max_len", type=int, default=1024, help="Max tokens for each prompt") + parser.add_argument("--generate_max_len", type=int, 
default=1024, help="Max tokens to generate in PPO") + parser.add_argument("--max_len", type=int, default=None, help="deprecated max_len") + parser.add_argument("--max_samples", type=int, default=1000000) + parser.add_argument("--max_norm", type=float, default=1.0, help="Gradient clipping") + parser.add_argument("--l2", type=float, default=0.0, help="weight decay loss") + parser.add_argument("--ptx_coef", type=float, default=0.05, help="PPO-ptx loss coef") + parser.add_argument("--eps_clip", type=float, default=0.2, help="PPO clip range") + parser.add_argument("--value_clip", type=float, default=0.2, help="PPO value clip range") + parser.add_argument("--lambd", type=float, default=1.0, help="PPO GAE lambd") + parser.add_argument("--gamma", type=float, default=1, help="PPO GAE gamma") + parser.add_argument("--micro_train_batch_size", type=int, default=4, help="batch size per GPU") + parser.add_argument("--train_batch_size", type=int, default=128, help="Global training batch size") + parser.add_argument("--normalize_reward", action="store_true", default=False, help="Enable Reward Normazation") + parser.add_argument("--top_p", type=float, default=1.0) + parser.add_argument("--temperature", type=float, default=1.0) + parser.add_argument("--freezing_actor_steps", type=int, default=-1, help="Used for critic initialization") + parser.add_argument( + "--n_samples_per_prompt", type=int, default=1, help="number of responses for each prompt in generation" + ) + parser.add_argument("--save_value_network", action="store_true", default=False, help="Save critic model") + parser.add_argument("--actor_learning_rate", type=float, default=1e-6) + parser.add_argument("--critic_learning_rate", type=float, default=9e-6) + parser.add_argument("--lr_warmup_ratio", type=float, default=0.03) + parser.add_argument("--kl_target", type=float, default=None) + parser.add_argument("--init_kl_coef", type=float, default=0.01, help="KL penalty in PPO") + parser.add_argument( + "--use_kl_estimator_k3", + action="store_true", + default=False, + help=( + "Use the k3 estimator in http://joschu.net/blog/kl-approx.html" + "to ensure the KL divergence calculated is non-negative" + ), + ) + parser.add_argument("--adam_betas", type=float, nargs=2, default=(0.9, 0.95), help="Betas for Adam optimizer") + parser.add_argument("--reward_clip_range", type=float, nargs=2, default=(-10, 10), help="Reward clip range") + + # DeepSpeed + parser.add_argument("--seed", type=int, default=42) + parser.add_argument("--local_rank", type=int, default=-1, help="local_rank for deepspeed") + parser.add_argument("--zero_stage", type=int, default=2, help="DeepSpeed ZeRO stage") + parser.add_argument("--gradient_checkpointing", action="store_true", default=False) + parser.add_argument("--bf16", action="store_true", default=False, help="Enable bfloat16") + parser.add_argument("--enable_ema", action="store_true", help="Enable EMA checkpoint for the model.") + parser.add_argument("--zpg", type=int, default=1, help="ZeRO++ max partition size") + parser.add_argument("--adam_offload", action="store_true", default=False, help="Offload Adam Optimizer") + parser.add_argument("--actor_init_on_gpu", action="store_true", default=False) + parser.add_argument("--flash_attn", action="store_true", default=False, help="Enable FlashAttention2") + parser.add_argument("--aux_loss_coef", type=float, default=0, help="MoE balancing loss") + parser.add_argument("--grad_accum_dtype", type=str, default=None, help="Adam grad accum data type") + parser.add_argument("--overlap_comm", 
action="store_true", default=False) + parser.add_argument("--gradient_checkpointing_use_reentrant", action="store_true", default=False) + parser.add_argument("--disable_fast_tokenizer", action="store_true", default=False) + + # Reinforce + parser.add_argument( + "--advantage_estimator", + type=str, + choices=["gae", "reinforce", "rloo"], + default="gae", + help="Choose advantage estimation method: gae, reinforce, rloo", + ) + + # LoRA + parser.add_argument("--load_in_4bit", action="store_true", default=False) + parser.add_argument("--lora_rank", type=int, default=0) + parser.add_argument("--lora_alpha", type=int, default=16) + parser.add_argument("--target_modules", type=str, nargs="*", default="all-linear") + parser.add_argument("--lora_dropout", type=float, default=0) + + # Models + parser.add_argument("--pretrain", type=str, default=None, help="HF model name or path") + parser.add_argument("--reward_pretrain", type=str, default=None, help="HF model name or path") + parser.add_argument("--remote_rm_url", type=str, default=None, help="remote RM API") + parser.add_argument("--critic_pretrain", type=str, default=None, help="HF model name or path") + parser.add_argument("--value_head_prefix", type=str, default="score") + + # Custom dataset + parser.add_argument("--prompt_data", type=str, default=None, help="HF dataset name or path") + parser.add_argument( + "--prompt_data_probs", + type=str, + default="1.0", + help="sampling probs for datasets", + ) + parser.add_argument("--prompt_split", type=str, default="train") + parser.add_argument("--pretrain_data", type=str, default=None, help="HF dataset name or path") + parser.add_argument( + "--pretrain_data_probs", + type=str, + default="1.0", + help="sampling probs for datasets", + ) + parser.add_argument("--pretrain_split", type=str, default="train") + parser.add_argument("--input_key", type=str, default="input", help="JSON dataset key") + parser.add_argument("--input_template", type=str, default=None) + parser.add_argument( + "--apply_chat_template", action="store_true", default=False, help="Use HF tokenizer chat template" + ) + + # wandb parameters + parser.add_argument("--use_wandb", type=str, default=None) + parser.add_argument("--wandb_org", type=str, default=None) + parser.add_argument("--wandb_group", type=str, default=None) + parser.add_argument("--wandb_project", type=str, default="openrlhf_train_ppo") + parser.add_argument( + "--wandb_run_name", + type=str, + default="ppo_%s" % datetime.now().strftime("%m%dT%H:%M"), + ) + + # TensorBoard parameters + parser.add_argument("--use_tensorboard", type=str, default=None, help="TensorBoard logging path") + + args = parser.parse_args() + + if args.advantage_estimator not in ["gae"]: + args.critic_pretrain = None + elif args.critic_pretrain is None: + if not args.remote_rm_url: + args.critic_pretrain = args.reward_pretrain + else: + args.critic_pretrain = args.pretrain + + if args.advantage_estimator == "rloo": + assert args.n_samples_per_prompt > 1, "RLOO requires n_samples_per_prompt > 1" + + if args.input_template and "{}" not in args.input_template: + print("[Warning] {} not in args.input_template, set to None") + args.input_template = None + + if args.input_template and "\\n" in args.input_template: + print( + "[Warning] input_template contains \\n chracters instead of newline. " + "You likely want to pass $'\\n' in Bash or \"`n\" in PowerShell." 
+ ) + + train(args) diff --git a/PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/openrlhf/cli/train_ppo_ray.py b/PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/openrlhf/cli/train_ppo_ray.py new file mode 100644 index 0000000000..245b855314 --- /dev/null +++ b/PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/openrlhf/cli/train_ppo_ray.py @@ -0,0 +1,400 @@ +import argparse +from datetime import datetime +from typing import List + +import ray +import torch +from ray.util.placement_group import placement_group + +from openrlhf.trainer.ray import ( + ActorModelRayActor, + CriticModelRayActor, + PPORayActorGroup, + ReferenceModelRayActor, + RewardModelRayActor, + create_vllm_engines, +) +from openrlhf.utils import get_strategy + + +# NOTE: reward function for multiple reward models, replace this with your own function! +def reward_fn(rewards: List[torch.Tensor]): + return torch.stack(rewards).sum(dim=0) + + +def _validate_args(args): + actor_world_size = args.actor_num_nodes * args.actor_num_gpus_per_node + + assert ( + args.rollout_batch_size % actor_world_size == 0 + ), f"rollout_bach_size must be divisible by actor_world_size, got {args.rollout_batch_size} and {actor_world_size}" + + assert args.zero_stage != 3 or args.vllm_num_engines > 0, f"ZeRO-3 is only supported when vLLM enabled" + + if args.vllm_num_engines > 0: + assert ( + actor_world_size % args.vllm_num_engines == 0 + ), f"actor_world_size must be divisible by vllm_num_engines, got {actor_world_size} and {args.vllm_num_engines}" + + if args.critic_pretrain: + critic_world_size = args.critic_num_nodes * args.critic_num_gpus_per_node + assert ( + actor_world_size % critic_world_size == 0 + ), f"actor_world_size must be divisible by critic_world_size, got {actor_world_size} and {critic_world_size}" + + +def train(args): + _validate_args(args) + + # configure strategy + strategy = get_strategy(args) + + # if colocated, create placement group for actor and ref model explicitly. + pg = None + if args.colocate_actor_ref: + assert ( + args.actor_num_nodes == args.ref_num_nodes and args.actor_num_gpus_per_node == args.ref_num_gpus_per_node + ), f"num_nodes and num_gpus_per_node must be the same when colocate actor and ref model." + + bundles = [ + {"GPU": args.actor_num_gpus_per_node, "CPU": args.actor_num_gpus_per_node} + for _ in range(args.actor_num_nodes) + ] + pg = placement_group(bundles, strategy="STRICT_SPREAD") + ray.get(pg.ready()) + + # NOTE(wuxibin): Why don't we allocate 0.5 gpu for each actor when colocate models? + # Say we have 1 node with 4 GPUs, and num_gpus_per_node for each model is 4. + # If we allocate 0.5 gpu for both actor and ref model, then gpu allocation is + # |actor|actor|actor|actor| ref | ref | ref | ref | + # |GPU0 |GPU0 |GPU1 |GPU1 | GPU2 | GPU2 | GPU3 | GPU3 | + # + # So 0.75/0.25 gpu is a tricky to let Ray spread all models evenly on all gpus. + # |actor| ref |actor| ref |actor| ref |actor|ref | + # |GPU0 | GPU0 |GPU1 | GPU1 |GPU2 | GPU2 |GPU3 | GPU3 | + actor_model = PPORayActorGroup( + args.actor_num_nodes, + args.actor_num_gpus_per_node, + ActorModelRayActor, + pg=pg, + num_gpus_per_actor=0.75 if pg else 1, + ) + + ref_model = PPORayActorGroup( + args.ref_num_nodes, + args.ref_num_gpus_per_node, + ReferenceModelRayActor, + pg=pg, + num_gpus_per_actor=0.25 if pg else 1, + ) + + # if colocated, create placement group for critic and reward model explicitly. 
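The default reward_fn at the top of this script combines several reward models by a plain per-sample sum; for instance, with two reward models scoring two responses (made-up values):

import torch

# one reward tensor per reward model, one entry per sampled response
rewards = [torch.tensor([0.2, -0.1]), torch.tensor([0.5, 0.3])]
combined = torch.stack(rewards).sum(dim=0)
print(combined)   # tensor([0.7000, 0.2000])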
+ pg = None + if args.critic_pretrain and args.colocate_critic_reward: + assert ( + args.critic_num_nodes == args.reward_num_nodes + and args.critic_num_gpus_per_node == args.reward_num_gpus_per_node + ), f"num_nodes and num_gpus_per_node must be the same when colocate critic and reward model." + + bundles = [ + {"GPU": args.critic_num_gpus_per_node, "CPU": args.critic_num_gpus_per_node} + for _ in range(args.critic_num_nodes) + ] + pg = placement_group(bundles, strategy="STRICT_SPREAD") + ray.get(pg.ready()) + + if args.critic_pretrain: + critic_model = PPORayActorGroup( + args.critic_num_nodes, + args.critic_num_gpus_per_node, + CriticModelRayActor, + pg=pg, + num_gpus_per_actor=0.75 if pg else 1, + ) + else: + critic_model = None + + # multiple reward models + if not args.remote_rm_url: + reward_pretrains = args.reward_pretrain.split(",") + reward_models = [] + for _ in reward_pretrains: + reward_models.append( + PPORayActorGroup( + args.reward_num_nodes, + args.reward_num_gpus_per_node, + RewardModelRayActor, + pg=pg, + num_gpus_per_actor=0.25 if pg else 1, + ) + ) + else: + reward_models = None + + # init reference/reward/actor model + refs = [] + refs.extend(ref_model.async_init_model_from_pretrained(strategy, args.pretrain)) + refs.extend(actor_model.async_init_model_from_pretrained(strategy, args.pretrain)) + if not args.remote_rm_url: + for reward_model, reward_pretrain in zip(reward_models, reward_pretrains): + refs.extend(reward_model.async_init_model_from_pretrained(strategy, reward_pretrain)) + + # init vLLM engine for text generation + vllm_engines = None + if args.vllm_num_engines is not None and args.vllm_num_engines > 0: + max_len = args.max_len if args.max_len else args.prompt_max_len + args.generate_max_len + vllm_engines = create_vllm_engines( + args.vllm_num_engines, + args.vllm_tensor_parallel_size, + args.pretrain, + args.seed, + args.enable_prefix_caching, + args.enforce_eager, + max_len, + ) + + ray.get(refs) + + if args.critic_pretrain: + # critic scheduler initialization depends on max_step, so we have to init critic after actor + # TODO: use first reward model as critic model + max_steps = ray.get(actor_model._actor_handlers[0].max_steps.remote()) + refs.extend(critic_model.async_init_model_from_pretrained(strategy, args.critic_pretrain, max_steps)) + ray.get(refs) + + # train actor and critic mdoel + refs = actor_model.async_fit_actor_model( + critic_model, ref_model, reward_models, args.remote_rm_url, reward_fn=reward_fn, vllm_engines=vllm_engines + ) + ray.get(refs) + + # save model + ray.get(actor_model.async_save_model()) + + if args.critic_pretrain and args.save_value_network: + ray.get(critic_model.async_save_model()) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + # Ray and vLLM + parser.add_argument("--ref_num_nodes", type=int, default=1, help="number of nodes for reference") + parser.add_argument("--ref_num_gpus_per_node", type=int, default=8, help="number of gpus per node for reference") + parser.add_argument("--reward_num_nodes", type=int, default=1, help="number of nodes for reward model") + parser.add_argument( + "--reward_num_gpus_per_node", type=int, default=8, help="number of gpus per node for reward model" + ) + parser.add_argument( + "--colocate_actor_ref", + action="store_true", + default=False, + help="whether to colocate reference and actor model, if true, they will share same gpus.", + ) + + parser.add_argument("--actor_num_nodes", type=int, default=1, help="number of nodes for actor") + 
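These node and GPU counts give the actor world size that _validate_args checks divisibility against; a quick sanity check for a hypothetical single-node run:

actor_num_nodes, actor_num_gpus_per_node = 1, 8   # hypothetical cluster shape
rollout_batch_size = 1024                         # --rollout_batch_size
vllm_num_engines = 4                              # --vllm_num_engines

actor_world_size = actor_num_nodes * actor_num_gpus_per_node   # 8
assert rollout_batch_size % actor_world_size == 0              # rollouts split evenly across actors
assert vllm_num_engines == 0 or actor_world_size % vllm_num_engines == 0
print("passes the _validate_args constraints")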
parser.add_argument("--actor_num_gpus_per_node", type=int, default=8, help="number of gpus per node for actor") + parser.add_argument("--critic_num_nodes", type=int, default=1, help="number of nodes for critic") + parser.add_argument("--critic_num_gpus_per_node", type=int, default=8, help="number of gpus per node for critic") + parser.add_argument( + "--colocate_critic_reward", + action="store_true", + default=False, + help="whether to colocate critic and reward model, if true, they will share same gpus.", + ) + + # optional vLLM for text generation + parser.add_argument( + "--vllm_num_engines", type=int, default=None, help="number of vLLM Engines, set to 0 to disable vLLM" + ) + parser.add_argument( + "--vllm_tensor_parallel_size", + type=int, + default=1, + help="tensor parallel size of vLLM Engine for multi-GPU inference", + ) + parser.add_argument("--vllm_sync_backend", type=str, default="nccl", help="DeepSpeed -> vLLM weight sync backend") + parser.add_argument("--vllm_sync_with_ray", action="store_true", default=False) + parser.add_argument("--enable_prefix_caching", action="store_true", default=False) + parser.add_argument("--enforce_eager", action="store_true", default=False, help="Disable CUDA graph in vLLM") + + # Checkpoints + parser.add_argument("--eval_steps", type=int, default=-1) + parser.add_argument("--save_steps", type=int, default=-1) + parser.add_argument("--logging_steps", type=int, default=1) + parser.add_argument("--ckpt_path", type=str, default="./ckpt/checkpoints_ppo_ray") + parser.add_argument("--save_hf_ckpt", action="store_true", default=False) + parser.add_argument("--disable_ds_ckpt", action="store_true", default=False) + parser.add_argument("--max_ckpt_num", type=int, default=3) + parser.add_argument("--max_ckpt_mem", type=int, default=1e8) + parser.add_argument("--load_checkpoint", action="store_true", default=False) + + # DeepSpeed + parser.add_argument("--local_rank", type=int, default=-1, help="local_rank for deepspeed") + parser.add_argument("--zero_stage", type=int, default=2, help="DeepSpeed ZeRO stage") + parser.add_argument("--gradient_checkpointing", action="store_true", default=False) + parser.add_argument("--bf16", action="store_true", default=False, help="Enable bfloat16") + ## Make EMA as an optional feature + parser.add_argument("--enable_ema", action="store_true", help="Enable EMA checkpoint for the model.") + parser.add_argument("--zpg", type=int, default=1, help="ZeRO++ max partition size") + parser.add_argument("--adam_offload", action="store_true", default=False, help="Offload Adam Optimizer") + parser.add_argument("--actor_init_on_gpu", action="store_true", default=False) + parser.add_argument("--flash_attn", action="store_true", default=False, help="Enable FlashAttention2") + parser.add_argument("--grad_accum_dtype", type=str, default=None, help="Adam grad accum data type") + parser.add_argument("--overlap_comm", action="store_true", default=False) + parser.add_argument("--gradient_checkpointing_use_reentrant", action="store_true", default=False) + parser.add_argument("--disable_fast_tokenizer", action="store_true", default=False) + + # packing samples using Flash Attention2 + parser.add_argument("--packing_samples", action="store_true", default=False) + + # LoRA + parser.add_argument("--load_in_4bit", action="store_true", default=False) + parser.add_argument("--lora_rank", type=int, default=0) + parser.add_argument("--lora_alpha", type=int, default=16) + parser.add_argument("--target_modules", type=str, nargs="*", default="all-linear") 
+ parser.add_argument("--lora_dropout", type=float, default=0) + + # PPO + parser.add_argument("--save_path", type=str, default="./ckpt") + parser.add_argument("--num_episodes", type=int, default=1) + parser.add_argument("--rollout_batch_size", type=int, default=1024) + parser.add_argument("--micro_rollout_batch_size", type=int, default=8) + parser.add_argument("--max_epochs", type=int, default=1) + parser.add_argument("--prompt_max_len", type=int, default=1024, help="Max tokens for each prompt") + parser.add_argument("--generate_max_len", type=int, default=1024, help="Max tokens to generate in PPO") + parser.add_argument("--max_len", type=int, default=None, help="deprecated max_len") + parser.add_argument("--max_samples", type=int, default=1e8, help="Max number of samples") + parser.add_argument("--max_norm", type=float, default=1.0, help="Gradient clipping") + parser.add_argument("--l2", type=float, default=0.0, help="weight decay loss") + parser.add_argument("--ptx_coef", type=float, default=0.05, help="PPO-ptx loss coef") + parser.add_argument("--eps_clip", type=float, default=0.2, help="PPO clip range") + parser.add_argument("--value_clip", type=float, default=0.2, help="PPO value clip range") + parser.add_argument("--lambd", type=float, default=1.0, help="PPO GAE lambd") + parser.add_argument("--gamma", type=float, default=1, help="PPO GAE gamma") + parser.add_argument("--micro_train_batch_size", type=int, default=4, help="batch size per GPU") + parser.add_argument("--train_batch_size", type=int, default=128, help="Global training batch size") + parser.add_argument("--normalize_reward", action="store_true", default=False, help="Enable Reward Normazation") + parser.add_argument("--top_p", type=float, default=1.0) + parser.add_argument("--temperature", type=float, default=1.0) + parser.add_argument("--seed", type=int, default=42) + parser.add_argument("--freezing_actor_steps", type=int, default=-1, help="Used for critic initialization") + parser.add_argument( + "--n_samples_per_prompt", type=int, default=1, help="number of responses for each prompt in generation" + ) + parser.add_argument("--save_value_network", action="store_true", default=False, help="Save critic model") + parser.add_argument("--actor_learning_rate", type=float, default=1e-6) + parser.add_argument("--critic_learning_rate", type=float, default=9e-6) + parser.add_argument("--lr_warmup_ratio", type=float, default=0.03) + parser.add_argument("--kl_target", type=float, default=None) + parser.add_argument("--init_kl_coef", type=float, default=0.01, help="KL penalty in PPO") + parser.add_argument( + "--use_kl_estimator_k3", + action="store_true", + default=False, + help=( + "Use the k3 estimator in http://joschu.net/blog/kl-approx.html" + "to ensure the KL divergence calculated is non-negative" + ), + ) + parser.add_argument("--aux_loss_coef", type=float, default=0, help="MoE balancing loss") + parser.add_argument("--adam_betas", type=float, nargs=2, default=(0.9, 0.95), help="Betas for Adam optimizer") + parser.add_argument("--reward_clip_range", type=float, nargs=2, default=(-10, 10), help="Reward clip range") + + # Reinforce + parser.add_argument( + "--advantage_estimator", + type=str, + choices=["gae", "reinforce", "rloo"], + default="gae", + help="Choose advantage estimation method: gae, reinforce, rloo", + ) + + # Models + parser.add_argument("--pretrain", type=str, default=None, help="HF model name or path") + parser.add_argument("--reward_pretrain", type=str, default=None, help="HF model name or path") + 
parser.add_argument("--remote_rm_url", type=str, default=None, help="remote RM API (HTTP)") + parser.add_argument("--critic_pretrain", type=str, default=None, help="HF model name or path") + parser.add_argument("--value_head_prefix", type=str, default="score") + parser.add_argument("--ref_reward_offload", action="store_true", default=False) + + # Custom dataset + parser.add_argument("--prompt_data", type=str, default=None, help="HF dataset name or path") + parser.add_argument( + "--prompt_data_probs", + type=str, + default="1.0", + help="sampling probs for datasets", + ) + parser.add_argument("--prompt_split", type=str, default="train") + parser.add_argument("--pretrain_data", type=str, default=None, help="HF dataset name or path") + parser.add_argument( + "--pretrain_data_probs", + type=str, + default="1.0", + help="sampling probs for datasets", + ) + parser.add_argument("--pretrain_split", type=str, default="train") + + parser.add_argument("--input_key", type=str, default="input", help="JSON dataset key") + parser.add_argument("--input_template", type=str, default=None) + parser.add_argument( + "--apply_chat_template", action="store_true", default=False, help="Use HF tokenizer chat template" + ) + + # wandb parameters + parser.add_argument("--use_wandb", type=str, default=None) + parser.add_argument("--wandb_org", type=str, default=None) + parser.add_argument("--wandb_group", type=str, default=None) + parser.add_argument("--wandb_project", type=str, default="openrlhf_train_ppo") + parser.add_argument( + "--wandb_run_name", + type=str, + default="ppo_%s" % datetime.now().strftime("%m%dT%H:%M"), + ) + + # TensorBoard parameters + parser.add_argument("--use_tensorboard", type=str, default=None, help="TensorBoard logging path") + + # performance tuning + parser.add_argument("--perf", action="store_true", default=False) + + args = parser.parse_args() + + if args.advantage_estimator not in ["gae"]: + args.critic_pretrain = None + elif args.critic_pretrain is None: + if not args.remote_rm_url: + args.critic_pretrain = args.reward_pretrain.split(",")[0] + else: + args.critic_pretrain = args.pretrain + + if args.advantage_estimator == "rloo": + assert args.n_samples_per_prompt > 1, "RLOO requires n_samples_per_prompt > 1" + + if args.remote_rm_url: + args.remote_rm_url = args.remote_rm_url.split(",") + + if args.vllm_num_engines >= 1 and args.enable_prefix_caching: + import vllm + if vllm.__version__ < "0.7.0": + args.enable_prefix_caching = False + print("[Warning] Disable prefix cache because vLLM updates weights without updating the old KV Cache for vLLM version below 0.7.0.") + + if args.input_template and "{}" not in args.input_template: + print("[Warning] {} not in args.input_template, set to None") + args.input_template = None + + if args.input_template and "\\n" in args.input_template: + print( + "[Warning] input_template contains \\n chracters instead of newline. " + "You likely want to pass $'\\n' in Bash or \"`n\" in PowerShell." + ) + + if args.packing_samples: + if not args.flash_attn: + print("[Warning] Please --flash_attn to accelerate when --packing_samples is enabled.") + args.flash_attn = True + assert args.vllm_num_engines > 0, "Only support `--packing_samples` with vLLM." + assert not args.pretrain_data, "`--pretrain_data` is not supported with `--packing_samples` yet." 
+ + train(args) diff --git a/PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/openrlhf/cli/train_prm.py b/PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/openrlhf/cli/train_prm.py new file mode 100644 index 0000000000..d5a2b7ce13 --- /dev/null +++ b/PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/openrlhf/cli/train_prm.py @@ -0,0 +1,204 @@ +import argparse +import math +import os +from datetime import datetime + +from transformers.trainer import get_scheduler + +from openrlhf.datasets import ProcessRewardDataset +from openrlhf.models import Actor +from openrlhf.trainer import ProcessRewardModelTrainer +from openrlhf.utils import blending_datasets, get_strategy, get_tokenizer + + +def train(args): + # configure strategy + strategy = get_strategy(args) + strategy.setup_distributed() + + # configure model + # load huggingface model + model = Actor( + args.pretrain, + use_flash_attention_2=args.flash_attn, + bf16=args.bf16, + load_in_4bit=args.load_in_4bit, + lora_rank=args.lora_rank, + lora_alpha=args.lora_alpha, + target_modules=args.target_modules, + lora_dropout=args.lora_dropout, + ds_config=strategy.get_ds_train_config(is_actor=True), + packing_samples=args.packing_samples, + ) + # configure tokenizer + tokenizer = get_tokenizer(args.pretrain, model.model, "right", strategy, use_fast=not args.disable_fast_tokenizer) + strategy.print(model) + + # gradient_checkpointing + if args.gradient_checkpointing: + model.gradient_checkpointing_enable( + gradient_checkpointing_kwargs={"use_reentrant": args.gradient_checkpointing_use_reentrant} + ) + + # configure optimizer + optim = strategy.create_optimizer(model, lr=args.learning_rate, betas=args.adam_betas, weight_decay=args.l2) + + # prepare for data and dataset + train_data, eval_data = blending_datasets( + args.dataset, + args.dataset_probs, + strategy, + args.seed, + max_count=args.max_samples, + train_split=args.train_split, + eval_split=args.eval_split, + ) + train_data = train_data.select(range(min(args.max_samples, len(train_data)))) + eval_data = eval_data.select(range(min(args.max_samples, len(eval_data)))) + train_dataset = ProcessRewardDataset(train_data, tokenizer, args.max_len, strategy) + eval_dataset = ProcessRewardDataset(eval_data, tokenizer, args.max_len, strategy) + + # prepare dataloader + train_dataloader = strategy.setup_dataloader( + train_dataset, + args.micro_train_batch_size, + True, + True, + train_dataset.packing_collate_fn if args.packing_samples else train_dataset.collate_fn, + ) + eval_dataloader = strategy.setup_dataloader( + eval_dataset, + args.micro_train_batch_size, + True, + False, + eval_dataset.packing_collate_fn if args.packing_samples else eval_dataset.collate_fn, + ) + + # scheduler + num_update_steps_per_epoch = len(train_dataset) // args.train_batch_size + max_steps = math.ceil(args.max_epochs * num_update_steps_per_epoch) + + scheduler = get_scheduler( + args.lr_scheduler, + optim, + num_warmup_steps=math.ceil(max_steps * args.lr_warmup_ratio), + num_training_steps=max_steps, + scheduler_specific_kwargs={"min_lr": args.learning_rate * 0.1}, + ) + + # prepare models + (model, optim, scheduler) = strategy.prepare((model, optim, scheduler)) + + # load checkpoint + consumed_samples = 0 + if args.load_checkpoint and os.path.exists(args.ckpt_path): + _, states = strategy.load_ckpt(model.model, args.ckpt_path) + consumed_samples = states["consumed_samples"] + strategy.print(f"Loaded the checkpoint: {args.ckpt_path}, consumed_samples: {consumed_samples}") + + os.makedirs(args.save_path, exist_ok=True) + + # 
configure Trainer + trainer = ProcessRewardModelTrainer( + model=model, + strategy=strategy, + optim=optim, + train_dataloader=train_dataloader, + eval_dataloader=eval_dataloader, + scheduler=scheduler, + max_norm=args.max_norm, + batch_size=args.train_batch_size, + max_epochs=args.max_epochs, + tokenizer=tokenizer, + ) + + trainer.fit(args, consumed_samples, num_update_steps_per_epoch) + + # save model checkpoint after fitting on only rank0 + strategy.save_model(model, tokenizer, args.save_path) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + # Checkpoint + parser.add_argument("--save_path", type=str, default="./ckpt") + parser.add_argument("--save_steps", type=int, default=-1) + parser.add_argument("--logging_steps", type=int, default=1) + parser.add_argument("--eval_steps", type=int, default=-1) + parser.add_argument("--ckpt_path", type=str, default="./ckpt/checkpoints_prm") + parser.add_argument("--max_ckpt_num", type=int, default=3) + parser.add_argument("--max_ckpt_mem", type=int, default=1e8) + parser.add_argument("--load_checkpoint", action="store_true", default=False) + + # DeepSpeed + parser.add_argument("--max_norm", type=float, default=1.0, help="Gradient clipping") + parser.add_argument("--gradient_checkpointing", action="store_true", default=False) + parser.add_argument("--seed", type=int, default=42) + parser.add_argument("--local_rank", type=int, default=-1, help="local_rank for deepspeed") + parser.add_argument("--zero_stage", type=int, default=2, help="DeepSpeed ZeRO stage") + parser.add_argument("--bf16", action="store_true", default=False, help="Enable bfloat16") + parser.add_argument("--zpg", type=int, default=1, help="ZeRO++ max partition size") + parser.add_argument("--adam_offload", action="store_true", default=False, help="Offload Adam Optimizer") + parser.add_argument("--flash_attn", action="store_true", default=False, help="Enable FlashAttention2") + parser.add_argument("--grad_accum_dtype", type=str, default=None, help="Adam grad accum data type") + parser.add_argument("--overlap_comm", action="store_true", default=False) + parser.add_argument("--gradient_checkpointing_use_reentrant", action="store_true", default=False) + parser.add_argument("--disable_fast_tokenizer", action="store_true", default=False) + + # LoRA + parser.add_argument("--load_in_4bit", action="store_true", default=False) + parser.add_argument("--lora_rank", type=int, default=0) + parser.add_argument("--lora_alpha", type=int, default=16) + parser.add_argument("--lora_dropout", type=float, default=0) + parser.add_argument("--target_modules", type=str, nargs="*", default="all-linear") + + # PRM training + parser.add_argument("--pretrain", type=str, default=None) + parser.add_argument("--max_epochs", type=int, default=1) + parser.add_argument("--aux_loss_coef", type=float, default=0, help="MoE balancing loss") + parser.add_argument("--learning_rate", type=float, default=1e-6) + parser.add_argument("--lr_warmup_ratio", type=float, default=0.03) + parser.add_argument("--micro_train_batch_size", type=int, default=1) + parser.add_argument("--train_batch_size", type=int, default=128, help="Global training batch size") + parser.add_argument("--lr_scheduler", type=str, default="cosine_with_min_lr") + parser.add_argument("--l2", type=float, default=0.0, help="weight decay loss") + parser.add_argument("--adam_betas", type=float, nargs=2, default=(0.9, 0.95), help="Betas for Adam optimizer") + parser.add_argument("--placeholder_token", type=str, default=None) + 
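--lr_scheduler defaults to cosine_with_min_lr, i.e. a short warmup followed by cosine decay that bottoms out at 10% of the peak learning rate (the min_lr passed through scheduler_specific_kwargs). A small self-contained check of that floor, assuming a transformers release that ships this scheduler, as these scripts do:

import math
import torch
from transformers import get_scheduler

lr, max_steps, warmup_ratio = 1e-6, 100, 0.03
optim = torch.optim.AdamW(torch.nn.Linear(4, 4).parameters(), lr=lr, betas=(0.9, 0.95))
scheduler = get_scheduler(
    "cosine_with_min_lr",
    optim,
    num_warmup_steps=math.ceil(max_steps * warmup_ratio),
    num_training_steps=max_steps,
    scheduler_specific_kwargs={"min_lr": lr * 0.1},
)
for _ in range(max_steps):
    optim.step()       # no gradients needed just to advance the schedule
    scheduler.step()
print(scheduler.get_last_lr())   # decays to min_lr = 1e-7 at the final step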
parser.add_argument("--reward_tokens", type=str, nargs="*", default=None) + + # packing samples using Flash Attention2 + parser.add_argument("--packing_samples", action="store_true", default=False) + + # custom dataset + parser.add_argument("--dataset", type=str, default=None) + parser.add_argument("--dataset_probs", type=str, default="1.0", help="sampling probs for datasets") + parser.add_argument("--train_split", type=str, default="train", help="train split of the HF dataset") + parser.add_argument("--eval_split", type=str, default="test", help="test split of the dataset") + + parser.add_argument("--input_key", type=str, default="input", help="JSON dataset key") + parser.add_argument("--label_key", type=str, default="label", help="JSON dataset key") + parser.add_argument("--max_samples", type=int, default=1e8, help="Max number of samples") + parser.add_argument("--max_len", type=int, default=2048, help="Max tokens for the samples") + + # wandb parameters + parser.add_argument("--use_wandb", type=str, default=None) + parser.add_argument("--wandb_org", type=str, default=None) + parser.add_argument("--wandb_group", type=str, default=None) + parser.add_argument("--wandb_project", type=str, default="openrlhf_train_prm") + parser.add_argument( + "--wandb_run_name", + type=str, + default="prm_%s" % datetime.now().strftime("%m%dT%H:%M"), + ) + + args = parser.parse_args() + + # Add positive token and negative token to reward_tokens and remove duplicates + if args.reward_tokens is not None: + print( + "If you are running with soft labels (float values), " + f"the first token in reward_tokens ({args.reward_tokens[0]}) should be the positive token " + "and the second token should be the negative token." + ) + + train(args) diff --git a/PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/openrlhf/cli/train_rm.py b/PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/openrlhf/cli/train_rm.py new file mode 100644 index 0000000000..5e23b19c8f --- /dev/null +++ b/PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/openrlhf/cli/train_rm.py @@ -0,0 +1,261 @@ +import argparse +import math +import os +from datetime import datetime + +from transformers.trainer import get_scheduler + +from openrlhf.datasets import RewardDataset +from openrlhf.models import get_llm_for_sequence_regression +from openrlhf.trainer import RewardModelTrainer +from openrlhf.utils import blending_datasets, get_strategy, get_tokenizer + + +def train(args): + # configure strategy + strategy = get_strategy(args) + strategy.setup_distributed() + + # configure model + # load huggingface model/config + model = get_llm_for_sequence_regression( + args.pretrain, + "reward", + use_flash_attention_2=args.flash_attn, + bf16=args.bf16, + load_in_4bit=args.load_in_4bit, + lora_rank=args.lora_rank, + lora_alpha=args.lora_alpha, + target_modules=args.target_modules, + lora_dropout=args.lora_dropout, + ds_config=strategy.get_ds_train_config(is_actor=False), + init_value_head=True, + value_head_prefix=args.value_head_prefix, + packing_samples=args.packing_samples, + ) + + # configure tokenizer + tokenizer = get_tokenizer(args.pretrain, model, "left", strategy, use_fast=not args.disable_fast_tokenizer) + + strategy.print(model) + + # configure optimizer + optim = strategy.create_optimizer(model, lr=args.learning_rate, betas=args.adam_betas, weight_decay=args.l2) + + # prepare for data and dataset + train_data, eval_data = blending_datasets( + args.dataset, + args.dataset_probs, + strategy, + args.seed, + max_count=args.max_samples, + 
stopping_strategy="all_exhausted", + train_split=args.train_split, + eval_split=args.eval_split, + ) + train_data = train_data.select(range(min(args.max_samples, len(train_data)))) + eval_data = eval_data.select(range(min(args.max_samples, len(eval_data)))) + train_dataset = RewardDataset( + train_data, + tokenizer, + args.max_len, + strategy, + input_template=args.input_template, + multiple_of=args.ring_attn_size, + ) + eval_dataset = RewardDataset( + eval_data, + tokenizer, + args.max_len, + strategy, + input_template=args.input_template, + multiple_of=args.ring_attn_size, + ) + + train_dataloader = strategy.setup_dataloader( + train_dataset, + args.micro_train_batch_size, + True, + True, + train_dataset.packing_collate_fn if args.packing_samples else train_dataset.collate_fn, + ) + eval_dataloader = strategy.setup_dataloader( + eval_dataset, + args.micro_train_batch_size, + True, + False, + eval_dataset.packing_collate_fn if args.packing_samples else eval_dataset.collate_fn, + ) + + # scheduler + num_update_steps_per_epoch = len(train_dataset) // args.train_batch_size + max_steps = math.ceil(args.max_epochs * num_update_steps_per_epoch) + + scheduler = get_scheduler( + "cosine_with_min_lr", + optim, + num_warmup_steps=math.ceil(max_steps * args.lr_warmup_ratio), + num_training_steps=max_steps, + scheduler_specific_kwargs={"min_lr": args.learning_rate * 0.1}, + ) + + # gradient_checkpointing + if args.gradient_checkpointing: + model.gradient_checkpointing_enable( + gradient_checkpointing_kwargs={"use_reentrant": args.gradient_checkpointing_use_reentrant} + ) + + # strategy prepare + (model, optim, scheduler) = strategy.prepare((model, optim, scheduler)) + + # load checkpoint + consumed_samples = 0 + if args.load_checkpoint and os.path.exists(args.ckpt_path): + _, states = strategy.load_ckpt(model, args.ckpt_path) + consumed_samples = states["consumed_samples"] + strategy.print(f"Loaded the checkpoint: {args.ckpt_path}, consumed_samples: {consumed_samples}") + + os.makedirs(args.save_path, exist_ok=True) + + # batch_size here is micro_batch_size * 2 + # we use merged chosen + rejected response forward + trainer = RewardModelTrainer( + model=model, + strategy=strategy, + optim=optim, + tokenizer=tokenizer, + train_dataloader=train_dataloader, + eval_dataloader=eval_dataloader, + scheduler=scheduler, + max_norm=args.max_norm, + max_epochs=args.max_epochs, + loss=args.loss, + ) + + trainer.fit(args, consumed_samples, num_update_steps_per_epoch) + + # Save value_head_prefix + strategy.print("Save value_head_prefix in config") + unwrap_model = strategy._unwrap_model(model) + unwrap_model.config.value_head_prefix = args.value_head_prefix + + # save model checkpoint after fitting on only rank0 + strategy.save_model(model, tokenizer, args.save_path) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + + # Checkpoint + parser.add_argument("--save_path", type=str, default="./ckpt") + parser.add_argument("--save_steps", type=int, default=-1) + parser.add_argument("--logging_steps", type=int, default=1) + parser.add_argument("--eval_steps", type=int, default=-1) + parser.add_argument("--ckpt_path", type=str, default="./ckpt/checkpoints_rm") + parser.add_argument("--max_ckpt_num", type=int, default=3) + parser.add_argument("--max_ckpt_mem", type=int, default=1e8) + parser.add_argument("--load_checkpoint", action="store_true", default=False) + + # DeepSpeed + parser.add_argument("--max_norm", type=float, default=1.0, help="Gradient clipping") + 
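As the comment near the RewardModelTrainer above notes, each preference pair goes through the reward model as one merged batch, so the forward batch is twice micro_train_batch_size. Schematically (the mean below is only a stand-in for the real value head):

import torch

micro_train_batch_size = 4
chosen_ids = torch.randint(0, 100, (micro_train_batch_size, 16))
rejected_ids = torch.randint(0, 100, (micro_train_batch_size, 16))

merged = torch.cat([chosen_ids, rejected_ids], dim=0)   # forward batch of 8 sequences
scores = merged.float().mean(dim=1)                     # stand-in for the value head score
chosen_scores, rejected_scores = scores.chunk(2, dim=0)
print(chosen_scores.shape, rejected_scores.shape)       # torch.Size([4]) torch.Size([4])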
parser.add_argument("--gradient_checkpointing", action="store_true", default=False) + parser.add_argument("--seed", type=int, default=42) + parser.add_argument("--local_rank", type=int, default=-1, help="local_rank for deepspeed") + parser.add_argument("--zero_stage", type=int, default=2, help="DeepSpeed ZeRO stage") + parser.add_argument("--bf16", action="store_true", default=False, help="Enable bfloat16") + parser.add_argument("--zpg", type=int, default=1, help="ZeRO++ max partition size") + parser.add_argument("--adam_offload", action="store_true", default=False, help="Offload Adam Optimizer") + parser.add_argument("--flash_attn", action="store_true", default=False, help="Enable FlashAttention2") + parser.add_argument("--grad_accum_dtype", type=str, default=None, help="Adam grad accum data type") + parser.add_argument("--overlap_comm", action="store_true", default=False) + parser.add_argument("--gradient_checkpointing_use_reentrant", action="store_true", default=False) + parser.add_argument("--disable_fast_tokenizer", action="store_true", default=False) + + # Models + parser.add_argument("--pretrain", type=str, default=None) + parser.add_argument("--value_head_prefix", type=str, default="score") + + # Context Parallel + parser.add_argument("--ring_attn_size", type=int, default=1, help="Ring attention group size") + parser.add_argument( + "--ring_head_stride", + type=int, + default=1, + help="the number of heads to do ring attention each time. " + "It should be a divisor of the number of heads. " + "A larger value may results in faster training but will consume more memory.", + ) + + # LoRA + parser.add_argument("--load_in_4bit", action="store_true", default=False) + parser.add_argument("--lora_rank", type=int, default=0) + parser.add_argument("--lora_alpha", type=int, default=16) + parser.add_argument("--lora_dropout", type=float, default=0) + parser.add_argument("--target_modules", type=str, nargs="*", default="all-linear") + + # RM training + parser.add_argument("--max_epochs", type=int, default=1) + parser.add_argument("--aux_loss_coef", type=float, default=0, help="MoE balancing loss") + parser.add_argument("--compute_fp32_loss", action="store_true", default=False) + parser.add_argument("--margin_loss", action="store_true", default=False) + parser.add_argument("--learning_rate", type=float, default=9e-6) + parser.add_argument("--lr_warmup_ratio", type=float, default=0.03) + parser.add_argument("--micro_train_batch_size", type=int, default=1) + parser.add_argument("--train_batch_size", type=int, default=128, help="Global training batch size") + parser.add_argument("--loss", type=str, default="sigmoid") + parser.add_argument("--l2", type=float, default=0.0, help="weight decay loss") + parser.add_argument("--adam_betas", type=float, nargs=2, default=(0.9, 0.95), help="Betas for Adam optimizer") + + # packing samples using Flash Attention2 + parser.add_argument("--packing_samples", action="store_true", default=False) + + # Custom dataset + parser.add_argument("--dataset", type=str, default=None) + parser.add_argument("--dataset_probs", type=str, default="1.0", help="sampling probs for datasets") + parser.add_argument("--prompt_key", type=str, default=None) + parser.add_argument("--chosen_key", type=str, default="chosen") + parser.add_argument("--rejected_key", type=str, default="rejected") + parser.add_argument("--input_template", type=str, default=None) + parser.add_argument( + "--apply_chat_template", action="store_true", default=False, help="Use HF tokenizer chat template" + ) + 
parser.add_argument("--tokenizer_chat_template", type=str, default=None) + parser.add_argument("--train_split", type=str, default="train", help="train split of the HF dataset") + parser.add_argument("--eval_split", type=str, default="test", help="test split of the dataset") + parser.add_argument("--max_samples", type=int, default=1e8, help="Max number of samples") + parser.add_argument("--max_len", type=int, default=512) + + # wandb parameters + parser.add_argument("--use_wandb", type=str, default=None) + parser.add_argument("--wandb_org", type=str, default=None) + parser.add_argument("--wandb_group", type=str, default=None) + parser.add_argument("--wandb_project", type=str, default="openrlhf_train_rm") + parser.add_argument( + "--wandb_run_name", + type=str, + default="rm_%s" % datetime.now().strftime("%m%dT%H:%M"), + ) + + # TensorBoard parameters + parser.add_argument("--use_tensorboard", type=str, default=None, help="TensorBoard logging path") + + args = parser.parse_args() + + if args.input_template and "{}" not in args.input_template: + print("[Warning] {} not in args.input_template, set to None") + args.input_template = None + + if args.input_template and "\\n" in args.input_template: + print( + "[Warning] input_template contains \\n chracters instead of newline. " + "You likely want to pass $'\\n' in Bash or \"`n\" in PowerShell." + ) + + if args.packing_samples and not args.flash_attn: + print("[Warning] Please --flash_attn to accelerate when --packing_samples is enabled.") + args.flash_attn = True + + if args.ring_attn_size > 1: + assert args.packing_samples, "packing_samples must be enabled when using ring attention" + + train(args) diff --git a/PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/openrlhf/cli/train_sft.py b/PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/openrlhf/cli/train_sft.py new file mode 100644 index 0000000000..843e37adad --- /dev/null +++ b/PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/openrlhf/cli/train_sft.py @@ -0,0 +1,258 @@ +import argparse +import math +import os +from datetime import datetime + +from transformers.trainer import get_scheduler + +from openrlhf.datasets import SFTDataset +from openrlhf.models import Actor +from openrlhf.trainer import SFTTrainer +from openrlhf.utils import blending_datasets, get_strategy, get_tokenizer + + +def train(args): + # configure strategy + strategy = get_strategy(args) + strategy.setup_distributed() + + # configure model + # load huggingface model + model = Actor( + args.pretrain, + use_flash_attention_2=args.flash_attn, + bf16=args.bf16, + load_in_4bit=args.load_in_4bit, + lora_rank=args.lora_rank, + lora_alpha=args.lora_alpha, + target_modules=args.target_modules, + lora_dropout=args.lora_dropout, + ds_config=strategy.get_ds_train_config(is_actor=True), + packing_samples=args.packing_samples, + ) + # configure tokenizer + tokenizer = get_tokenizer(args.pretrain, model.model, "right", strategy, use_fast=not args.disable_fast_tokenizer) + strategy.print(model) + + # gradient_checkpointing + if args.gradient_checkpointing: + model.gradient_checkpointing_enable( + gradient_checkpointing_kwargs={"use_reentrant": args.gradient_checkpointing_use_reentrant} + ) + + # configure optimizer + optim = strategy.create_optimizer(model, lr=args.learning_rate, betas=args.adam_betas, weight_decay=args.l2) + + # prepare for data and dataset + train_data, eval_data = blending_datasets( + args.dataset, + args.dataset_probs, + strategy, + args.seed, + max_count=args.max_samples, + train_split=args.train_split, + 
eval_split=args.eval_split, + ) + train_data = train_data.select(range(min(args.max_samples, len(train_data)))) + eval_data = eval_data.select(range(min(args.max_samples, len(eval_data)))) + train_dataset = SFTDataset( + train_data, + tokenizer, + args.max_len, + strategy, + pretrain_mode=args.pretrain_mode, + input_template=args.input_template, + multiple_of=args.ring_attn_size, + multiturn=args.multiturn, + ) + eval_dataset = SFTDataset( + eval_data, + tokenizer, + args.max_len, + strategy, + pretrain_mode=args.pretrain_mode, + input_template=args.input_template, + multiple_of=args.ring_attn_size, + multiturn=args.multiturn, + ) + + # prepare dataloader + train_dataloader = strategy.setup_dataloader( + train_dataset, + args.micro_train_batch_size, + True, + True, + train_dataset.packing_collate_fn if args.packing_samples else train_dataset.collate_fn, + ) + eval_dataloader = strategy.setup_dataloader( + eval_dataset, + args.micro_train_batch_size, + True, + False, + eval_dataset.packing_collate_fn if args.packing_samples else eval_dataset.collate_fn, + ) + + # scheduler + num_update_steps_per_epoch = len(train_dataset) // args.train_batch_size + max_steps = math.ceil(args.max_epochs * num_update_steps_per_epoch) + + scheduler = get_scheduler( + args.lr_scheduler, + optim, + num_warmup_steps=math.ceil(max_steps * args.lr_warmup_ratio), + num_training_steps=max_steps, + scheduler_specific_kwargs={"min_lr": args.learning_rate * 0.1}, + ) + + # prepare models + (model, optim, scheduler) = strategy.prepare((model, optim, scheduler)) + + # load checkpoint + consumed_samples = 0 + if args.load_checkpoint and os.path.exists(args.ckpt_path): + _, states = strategy.load_ckpt(model.model, args.ckpt_path) + consumed_samples = states["consumed_samples"] + strategy.print(f"Loaded the checkpoint: {args.ckpt_path}, consumed_samples: {consumed_samples}") + + os.makedirs(args.save_path, exist_ok=True) + + # configure Trainer + trainer = SFTTrainer( + model=model, + strategy=strategy, + optim=optim, + train_dataloader=train_dataloader, + eval_dataloader=eval_dataloader, + scheduler=scheduler, + max_norm=args.max_norm, + pretrain_mode=args.pretrain_mode, + batch_size=args.train_batch_size, + max_epochs=args.max_epochs, + tokenizer=tokenizer, + save_hf_ckpt=args.save_hf_ckpt, + disable_ds_ckpt=args.disable_ds_ckpt, + ) + + trainer.fit(args, consumed_samples, num_update_steps_per_epoch) + + # save model checkpoint after fitting on only rank0 + strategy.save_model(model, tokenizer, args.save_path) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + # Checkpoint + parser.add_argument("--save_path", type=str, default="./ckpt") + parser.add_argument("--save_steps", type=int, default=-1) + parser.add_argument("--save_hf_ckpt", action="store_true", default=False) + parser.add_argument("--disable_ds_ckpt", action="store_true", default=False) + parser.add_argument("--logging_steps", type=int, default=1) + parser.add_argument("--eval_steps", type=int, default=-1) + parser.add_argument("--ckpt_path", type=str, default="./ckpt/checkpoints_sft") + parser.add_argument("--max_ckpt_num", type=int, default=3) + parser.add_argument("--max_ckpt_mem", type=int, default=1e8) + parser.add_argument("--load_checkpoint", action="store_true", default=False) + + # DeepSpeed + parser.add_argument("--micro_train_batch_size", type=int, default=8, help="batch size per GPU") + parser.add_argument("--train_batch_size", type=int, default=128, help="Global training batch size") + parser.add_argument("--max_norm", 
type=float, default=1.0, help="Gradient clipping") + parser.add_argument("--gradient_checkpointing", action="store_true", default=False) + parser.add_argument("--seed", type=int, default=42) + parser.add_argument("--local_rank", type=int, default=-1, help="local_rank for deepspeed") + parser.add_argument("--zero_stage", type=int, default=2, help="DeepSpeed ZeRO stage") + parser.add_argument("--bf16", action="store_true", default=False, help="Enable bfloat16") + parser.add_argument("--zpg", type=int, default=1, help="ZeRO++ max partition size") + parser.add_argument("--adam_offload", action="store_true", default=False, help="Offload Adam Optimizer") + parser.add_argument("--flash_attn", action="store_true", default=False, help="Enable FlashAttention2") + parser.add_argument("--grad_accum_dtype", type=str, default=None, help="Adam grad accum data type") + parser.add_argument("--overlap_comm", action="store_true", default=False) + parser.add_argument("--gradient_checkpointing_use_reentrant", action="store_true", default=False) + parser.add_argument("--disable_fast_tokenizer", action="store_true", default=False) + + # SFT + parser.add_argument("--max_epochs", type=int, default=2) + parser.add_argument("--aux_loss_coef", type=float, default=0, help="MoE balancing loss") + parser.add_argument("--pretrain", type=str, default=None) + parser.add_argument("--learning_rate", type=float, default=5e-6) + parser.add_argument("--lr_warmup_ratio", type=float, default=0.03) + parser.add_argument("--pretrain_mode", action="store_true", default=False, help="Use pretrain loss") + parser.add_argument("--lr_scheduler", type=str, default="cosine_with_min_lr") + parser.add_argument("--l2", type=float, default=0, help="weight decay loss") + parser.add_argument("--adam_betas", type=float, nargs=2, default=(0.9, 0.95), help="Betas for Adam optimizer") + + # ring-attention + parser.add_argument("--ring_attn_size", type=int, default=1, help="Ring attention group size") + parser.add_argument( + "--ring_head_stride", + type=int, + default=1, + help="the number of heads to do ring attention each time. " + "It should be a divisor of the number of heads. 
" + "A larger value may results in faster training but will consume more memory.", + ) + + # LoRA + parser.add_argument("--load_in_4bit", action="store_true", default=False) + parser.add_argument("--lora_rank", type=int, default=0) + parser.add_argument("--lora_alpha", type=int, default=16) + parser.add_argument("--target_modules", type=str, nargs="*", default="all-linear") + parser.add_argument("--lora_dropout", type=float, default=0) + + # packing SFT samples without CrossAttention + parser.add_argument("--packing_samples", action="store_true", default=False) + + # custom dataset + parser.add_argument("--dataset", type=str, default=None) + parser.add_argument("--dataset_probs", type=str, default="1.0", help="sampling probs for datasets") + parser.add_argument("--train_split", type=str, default="train", help="train split of the HF dataset") + parser.add_argument("--eval_split", type=str, default="test", help="test split of the dataset") + parser.add_argument("--multiturn", action="store_true", default=False, help="Use compacted multiturn dataset") + + parser.add_argument("--input_key", type=str, default="input", help="JSON dataset key") + parser.add_argument("--output_key", type=str, default=None, help="JSON dataset key") + parser.add_argument("--input_template", type=str, default="User: {}\nAssistant: ") + parser.add_argument( + "--apply_chat_template", action="store_true", default=False, help="Use HF tokenizer chat template" + ) + parser.add_argument("--tokenizer_chat_template", type=str, default=None) + parser.add_argument("--max_samples", type=int, default=1e8, help="Max number of samples") + parser.add_argument("--max_len", type=int, default=2048, help="Max tokens for the samples") + + # wandb parameters + parser.add_argument("--use_wandb", type=str, default=None) + parser.add_argument("--wandb_org", type=str, default=None) + parser.add_argument("--wandb_group", type=str, default=None) + parser.add_argument("--wandb_project", type=str, default="openrlhf_train_sft") + parser.add_argument( + "--wandb_run_name", + type=str, + default="sft_%s" % datetime.now().strftime("%m%dT%H:%M"), + ) + + # TensorBoard parameters + parser.add_argument("--use_tensorboard", type=str, default=None, help="TensorBoard logging path") + + args = parser.parse_args() + + if args.multiturn: + assert args.apply_chat_template, "apply_chat_template must be enabled when using multiturn format" + + if args.input_template and "{}" not in args.input_template: + print("[Warning] {} not in args.input_template, set to None") + args.input_template = None + + if args.input_template and "\\n" in args.input_template: + print( + "[Warning] input_template contains \\n chracters instead of newline. " + "You likely want to pass $'\\n' in Bash or \"`n\" in PowerShell." 
+ ) + + if args.packing_samples and not args.flash_attn: + print("[Warning] Please --flash_attn to accelerate when --packing_samples is enabled.") + args.flash_attn = True + + if args.ring_attn_size > 1: + assert args.packing_samples, "packing_samples must be enabled when using ring attention" + + train(args) diff --git a/PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/openrlhf/datasets/__init__.py b/PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/openrlhf/datasets/__init__.py new file mode 100644 index 0000000000..bbb762f1ea --- /dev/null +++ b/PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/openrlhf/datasets/__init__.py @@ -0,0 +1,7 @@ +from .process_reward_dataset import ProcessRewardDataset +from .prompts_dataset import PromptDataset +from .reward_dataset import RewardDataset +from .sft_dataset import SFTDataset +from .unpaired_preference_dataset import UnpairedPreferenceDataset + +__all__ = ["ProcessRewardDataset", "PromptDataset", "RewardDataset", "SFTDataset", "UnpairedPreferenceDataset"] diff --git a/PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/openrlhf/datasets/process_reward_dataset.py b/PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/openrlhf/datasets/process_reward_dataset.py new file mode 100644 index 0000000000..5936384bac --- /dev/null +++ b/PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/openrlhf/datasets/process_reward_dataset.py @@ -0,0 +1,136 @@ +import numbers +from typing import Callable + +import torch +import torch.nn.functional as F +from torch.utils.data import Dataset + +from openrlhf.utils.utils import convert_token_to_id +from .utils import zero_pad_sequences + + +class ProcessRewardDataset(Dataset): + """ + Dataset for process reward model + + Args: + dataset: dataset for reward model + self.tokenizer: self.tokenizer for reward model + self.max_length: max length of input + """ + + def __init__( + self, + dataset, + tokenizer: Callable, + max_length: int, + strategy, + multiple_of=1, + ) -> None: + super().__init__() + self.tokenizer = tokenizer + self.strategy = strategy + self.max_length = max_length + self.multiple_of = multiple_of + + # chat_template + self.input_key = getattr(self.strategy.args, "input_key", None) + self.label_key = getattr(self.strategy.args, "label_key", None) + self.placeholder_token = getattr(self.strategy.args, "placeholder_token", None) + self.reward_tokens = getattr(self.strategy.args, "reward_tokens", None) + + self.placeholder_token_id = convert_token_to_id(self.placeholder_token, self.tokenizer) + + # Store the processed data in class attributes + self.inputs = dataset[self.input_key] + self.labels = dataset[self.label_key] + + def __len__(self): + length = len(self.inputs) + return length + + def __getitem__(self, idx): + input_token = self.tokenizer( + self.inputs[idx], + max_length=self.max_length, + padding=False, + truncation=True, + return_tensors="pt", + add_special_tokens=False, + ) + + input_ids = input_token["input_ids"] + label_values = self.labels[idx] + assert isinstance(label_values, list), "labels should be a list of strings or numbers" + if isinstance(label_values[0], str): + label_tokens = [] + for label in label_values: + assert ( + self.reward_tokens is None or label in self.reward_tokens + ), f"label should be in reward tokens {self.reward_tokens}, got {label}" + label_tokens.append(convert_token_to_id(label, self.tokenizer)) + + # label_tokens is list of token id (for '+', '-', etc) + label_tensor = torch.tensor(label_tokens, dtype=input_ids.dtype) + else: + # label_values is list of float numbers (for 
reward values) + label_tensor = torch.tensor(label_values, dtype=torch.float) + # Motivation: inputs_ids maybe truncated to self.max_length, where placeholder_tokens at the end may be removed. + # We should also truncate the labels to match the length of input_ids + # Step 1: Create a mask for placeholder token positions + mask = input_ids == self.placeholder_token_id + # Step 2: Ensure that label_tensor is truncated along the last dimension + # Find the length of the last dimension of the mask + num_placeholders = mask.sum(dim=-1) + # Truncate label_tensor along the last dimension to match num_placeholders + truncated_labels = label_tensor[..., :num_placeholders.max()] + # Step 3: Update labels at placeholder token positions + labels = torch.full_like(input_ids, -100) + labels[mask] = truncated_labels + + return ( + input_ids, + input_token["attention_mask"], + labels, + ) + + def collate_fn(self, item_list): + input_ids = [] + input_masks = [] + label_ids = [] + for input_id, input_mask, label_id in item_list: + input_ids.append(input_id) + input_masks.append(input_mask) + label_ids.append(label_id) + + padding_side = "right" + input_ids = zero_pad_sequences(input_ids, side=padding_side, value=self.tokenizer.pad_token_id) + input_masks = zero_pad_sequences(input_masks, side=padding_side) + label_ids = zero_pad_sequences(label_ids, side=padding_side, value=self.tokenizer.pad_token_id) + return input_ids, input_masks, label_ids + + def packing_collate_fn(self, item_list): + input_ids = [] + input_att_masks = [] + input_seq_lens = [] + label_ids = [] + index = 1 + for input_id, input_mask, label_id in item_list: + input_ids.append(input_id.flatten()) + input_att_masks.append(torch.full_like(input_id.flatten(), index)) + input_seq_lens.append(len(input_id.flatten())) + + label_ids.append(label_id.flatten()) + index += 1 + + packed_input_ids = torch.cat(input_ids, dim=0).unsqueeze(0) + packed_attention_masks = torch.cat(input_att_masks, dim=0).unsqueeze(0) + packed_seq_lens = input_seq_lens + packed_label_ids = torch.cat(label_ids, dim=0).unsqueeze(0) + + if self.multiple_of > 1 and packed_input_ids.numel() % self.multiple_of != 0: + padding_len = self.multiple_of - (packed_input_ids.numel() % self.multiple_of) + packed_input_ids = F.pad(packed_input_ids, (0, padding_len), value=self.tokenizer.pad_token_id) + packed_attention_masks = F.pad(packed_attention_masks, (0, padding_len), value=0) + + return packed_input_ids, packed_attention_masks, packed_seq_lens, packed_label_ids diff --git a/PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/openrlhf/datasets/prompts_dataset.py b/PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/openrlhf/datasets/prompts_dataset.py new file mode 100644 index 0000000000..e1029f6df4 --- /dev/null +++ b/PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/openrlhf/datasets/prompts_dataset.py @@ -0,0 +1,57 @@ +from torch.utils.data import Dataset +from tqdm import tqdm + + +def preprocess_data(data, input_template=None, input_key="input", apply_chat_template=None) -> str: + if apply_chat_template: + chat = data[input_key] + if isinstance(chat, str): + chat = [{"role": "user", "content": chat}] + prompt = apply_chat_template(chat, tokenize=False, add_generation_prompt=True) + else: + prompt = data[input_key] + if input_template: + prompt = input_template.format(prompt) + return prompt + + +class PromptDataset(Dataset): + """ + Dataset for PPO model + + Args: + dataset: dataset for PPO model + tokenizer: tokenizer for PPO model + max_length: max length of input + """ + + def 
__init__( + self, + dataset, + tokenizer, + strategy, + input_template=None, + ) -> None: + super().__init__() + self.strategy = strategy + self.tokenizer = tokenizer + + # chat_template + self.input_template = input_template + input_key = getattr(self.strategy.args, "input_key", None) + apply_chat_template = getattr(self.strategy.args, "apply_chat_template", False) + + if apply_chat_template: + apply_chat_template = self.tokenizer.apply_chat_template + + self.prompts = [] + for data in tqdm(dataset, desc="Preprocessing data", disable=not self.strategy.is_rank_0()): + prompt = preprocess_data(data, input_template, input_key, apply_chat_template) + self.prompts.append(prompt) + + def __len__(self): + length = len(self.prompts) + return length + + def __getitem__(self, idx): + return self.prompts[idx] diff --git a/PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/openrlhf/datasets/reward_dataset.py b/PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/openrlhf/datasets/reward_dataset.py new file mode 100644 index 0000000000..c8fcef0dc1 --- /dev/null +++ b/PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/openrlhf/datasets/reward_dataset.py @@ -0,0 +1,235 @@ +from typing import Callable + +import torch +import torch.nn.functional as F +from torch.utils.data import Dataset + +from .utils import exist_and_not_none, zero_pad_sequences + + +def preprocess_data( + data, + input_template=None, + prompt_key=None, + chosen_key="chosen", + rejected_key="rejected", + apply_chat_template=None, + is_dpo=False, +) -> str: + if apply_chat_template: + if prompt_key: + prompt = apply_chat_template(data[prompt_key], tokenize=False, add_generation_prompt=True) + chosen = apply_chat_template(data[prompt_key] + data[chosen_key], tokenize=False)[len(prompt) :] + rejected = apply_chat_template(data[prompt_key] + data[rejected_key], tokenize=False)[len(prompt) :] + else: + prompt = "" + chosen = apply_chat_template(data[chosen_key], tokenize=False) + rejected = apply_chat_template(data[rejected_key], tokenize=False) + + if is_dpo: + prompt = apply_chat_template(data[chosen_key][:-1], tokenize=False, add_generation_prompt=True) + chosen = chosen[len(prompt) :] + rejected = rejected[len(prompt) :] + else: + if prompt_key: + prompt = data[prompt_key] + if input_template: + prompt = input_template.format(prompt) + else: + prompt = "" + chosen = data[chosen_key] + rejected = data[rejected_key] + + # margin loss + margin = data["margin"] if exist_and_not_none(data, "margin") else 0 + + return prompt, chosen, rejected, margin + + +class RewardDataset(Dataset): + """ + Dataset for reward model + + Args: + dataset: dataset for reward model + self.tokenizer: self.tokenizer for reward model + self.max_length: max length of input + """ + + def __init__( + self, + dataset, + tokenizer: Callable, + max_length: int, + strategy, + input_template=None, + is_dpo=False, + num_processors=8, + multiple_of=1, + ) -> None: + super().__init__() + self.is_dpo = is_dpo + self.tokenizer = tokenizer + self.strategy = strategy + self.max_length = max_length + self.multiple_of = multiple_of + + # chat_template + self.input_template = input_template + self.prompt_key = getattr(self.strategy.args, "prompt_key", None) + self.chosen_key = getattr(self.strategy.args, "chosen_key", None) + self.rejected_key = getattr(self.strategy.args, "rejected_key", None) + self.apply_chat_template = getattr(self.strategy.args, "apply_chat_template", False) + + if self.apply_chat_template: + self.apply_chat_template = self.tokenizer.apply_chat_template + 
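            # Note: an explicit --tokenizer_chat_template (read just below) overrides the chat template bundled with the tokenizer. +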
tokenizer_chat_template = getattr(self.strategy.args, "tokenizer_chat_template", None) + if tokenizer_chat_template: + self.tokenizer.chat_template = tokenizer_chat_template + + # Parallel loading datasets + processed_dataset = dataset.map( + self.process_data, remove_columns=dataset.column_names, num_proc=num_processors + ) + + # Filter out None values if necessary + processed_dataset = processed_dataset.filter(lambda x: x["prompt"] is not None) + + # Store the processed data in class attributes + self.prompts = processed_dataset["prompt"] + self.chosens = processed_dataset["chosen"] + self.rejects = processed_dataset["reject"] + self.extras = processed_dataset["extra"] + + def process_data(self, data): + prompt, chosen, reject, margin = preprocess_data( + data, + self.input_template, + self.prompt_key, + self.chosen_key, + self.rejected_key, + self.apply_chat_template, + self.is_dpo, + ) + + if self.is_dpo: + prompt_token = self.tokenizer( + prompt, + max_length=self.max_length, + padding=False, + truncation=True, + return_tensors="pt", + add_special_tokens=False, + ) + prompt_ids_len = prompt_token["attention_mask"].int().sum().item() + + # Filter the sample whose length is greater than max_length (2 for answer length) + if prompt_ids_len >= self.max_length - 2: + prompt = None + + return { + "prompt": prompt, + "chosen": chosen, + "reject": reject, + "extra": prompt_ids_len if self.is_dpo else margin, + } + + def __len__(self): + length = len(self.chosens) + return length + + def __getitem__(self, idx): + prompt, chosen, reject, extra = self.prompts[idx], self.chosens[idx], self.rejects[idx], self.extras[idx] + + chosen = (prompt + chosen).rstrip("\n") + if not chosen.endswith(self.tokenizer.eos_token): + chosen += " " + self.tokenizer.eos_token + chosen_token = self.tokenizer( + chosen, + max_length=self.max_length, + padding=False, + truncation=True, + return_tensors="pt", + add_special_tokens=False, + ) + + reject = (prompt + reject).rstrip("\n") + if not reject.endswith(self.tokenizer.eos_token): + reject += " " + self.tokenizer.eos_token + reject_token = self.tokenizer( + reject, + max_length=self.max_length, + padding=False, + truncation=True, + return_tensors="pt", + add_special_tokens=False, + ) + + # to avoid EOS_token truncation + chosen_token["input_ids"][0][-1] = self.tokenizer.eos_token_id + reject_token["input_ids"][0][-1] = self.tokenizer.eos_token_id + chosen_token["attention_mask"][0][-1] = True + reject_token["attention_mask"][0][-1] = True + + return ( + chosen_token["input_ids"], + chosen_token["attention_mask"], + reject_token["input_ids"], + reject_token["attention_mask"], + extra, + ) + + def collate_fn(self, item_list): + chosen_ids = [] + chosen_masks = [] + reject_ids = [] + rejects_masks = [] + extras = [] + for chosen_id, chosen_mask, reject_id, rejects_mask, extra in item_list: + chosen_ids.append(chosen_id) + chosen_masks.append(chosen_mask) + reject_ids.append(reject_id) + rejects_masks.append(rejects_mask) + extras.append(extra) + + if self.is_dpo: + padding_side = "right" + else: + padding_side = "left" + chosen_ids = zero_pad_sequences(chosen_ids, side=padding_side, value=self.tokenizer.pad_token_id) + chosen_masks = zero_pad_sequences(chosen_masks, side=padding_side) + reject_ids = zero_pad_sequences(reject_ids, side=padding_side, value=self.tokenizer.pad_token_id) + rejects_masks = zero_pad_sequences(rejects_masks, side=padding_side) + return chosen_ids, chosen_masks, reject_ids, rejects_masks, extras + + def packing_collate_fn(self, item_list): + 
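        # Packing: each chosen/rejected pair is flattened and all segments are concatenated into one long row; +
        # the attention mask stores a distinct integer per segment (1..N for chosen, N+1..2N for rejected) so packed sequences stay distinguishable. +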
extras = [] + + chosen_ids = [] + chosen_att_masks = [] + chosen_seq_lens = [] + rejected_ids = [] + rejected_att_masks = [] + rejected_seq_lens = [] + index = 1 + for chosen_id, chosen_mask, reject_id, rejects_mask, extra in item_list: + chosen_ids.append(chosen_id.flatten()) + chosen_att_masks.append(torch.full_like(chosen_id.flatten(), index)) + chosen_seq_lens.append(len(chosen_id.flatten())) + extras.append(extra) + + rejected_ids.append(reject_id.flatten()) + rejected_att_masks.append(torch.full_like(reject_id.flatten(), index + len(item_list))) + rejected_seq_lens.append(len(reject_id.flatten())) + index += 1 + + packed_input_ids = torch.cat(chosen_ids + rejected_ids, dim=0).unsqueeze(0) + packed_attention_masks = torch.cat(chosen_att_masks + rejected_att_masks, dim=0).unsqueeze(0) + packed_seq_lens = chosen_seq_lens + rejected_seq_lens + + if self.multiple_of > 1 and packed_input_ids.numel() % self.multiple_of != 0: + padding_len = self.multiple_of - (packed_input_ids.numel() % self.multiple_of) + packed_input_ids = F.pad(packed_input_ids, (0, padding_len), value=self.tokenizer.pad_token_id) + packed_attention_masks = F.pad(packed_attention_masks, (0, padding_len), value=0) + + return packed_input_ids, packed_attention_masks, packed_seq_lens, extras diff --git a/PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/openrlhf/datasets/sft_dataset.py b/PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/openrlhf/datasets/sft_dataset.py new file mode 100644 index 0000000000..6e031f70ab --- /dev/null +++ b/PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/openrlhf/datasets/sft_dataset.py @@ -0,0 +1,231 @@ +from typing import Callable + +import torch +import torch.nn.functional as F +from torch.utils.data import Dataset + +from .utils import zero_pad_sequences + + +def preprocess_data(data, input_template=None, input_key="input", output_key=None, apply_chat_template=None, multiturn=False): + if apply_chat_template: + if output_key: + prompt_message = data[input_key] + response_message = data[output_key] + + if isinstance(prompt_message, str) and isinstance(response_message, str): + prompt_message = [{"role": "user", "content": prompt_message}] + response_message = [{"role": "assistant", "content": response_message}] + + prompt = apply_chat_template(prompt_message, tokenize=False, add_generation_prompt=True) + response = apply_chat_template(prompt_message + response_message, tokenize=False)[len(prompt) :] + else: + prompt = apply_chat_template(data[input_key][:-1], tokenize=False, add_generation_prompt=True) + response = apply_chat_template(data[input_key], tokenize=False)[len(prompt) :] + else: + prompt = data[input_key] + if input_template: + prompt = input_template.format(prompt) + # output_key is None for continue pretrain + response = data[output_key] if output_key else "" + return prompt, response + + +class SFTDataset(Dataset): + """ + Dataset for SFT model + + Args: + dataset: dataset for SFT model + tokenizer: tokenizer for SFT model + max_length: max length of input + """ + + def __init__( + self, + dataset, + tokenizer: Callable, + max_length: int, + strategy, + input_template=None, + pretrain_mode=False, + num_processors=8, # Specify the number of processors you want to use + multiple_of=1, + multiturn=False, + ) -> None: + super().__init__() + self.tokenizer = tokenizer + self.strategy = strategy + self.pretrain_mode = pretrain_mode + self.max_length = max_length + self.multiple_of = multiple_of + self.multiturn = multiturn + + # chat template + self.input_template = input_template + 
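        # Note: input_template is only applied on the plain-text path; it is ignored when apply_chat_template or pretrain_mode is used (see process_data below). +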
self.input_key = getattr(self.strategy.args, "input_key", None) + self.output_key = getattr(self.strategy.args, "output_key", None) + self.apply_chat_template = getattr(self.strategy.args, "apply_chat_template", False) + + if self.apply_chat_template: + self.apply_chat_template = self.tokenizer.apply_chat_template + tokenizer_chat_template = getattr(self.strategy.args, "tokenizer_chat_template", None) + if tokenizer_chat_template: + self.tokenizer.chat_template = tokenizer_chat_template + + # Parallel loading datasets + processed_dataset = dataset.map( + self.process_data, + remove_columns=dataset.column_names, + num_proc=num_processors, + ) + processed_dataset = processed_dataset.filter(lambda x: x["prompt"] is not None) + + # Store the processed data in class attributes + self.prompts = processed_dataset["prompt"] + self.responses = processed_dataset["response"] + self.prompt_ids_lens = processed_dataset["prompt_ids_len"] + self.response_ranges = processed_dataset["response_ranges"] if self.multiturn else None + + def process_data(self, data): + if self.multiturn and self.output_key: + data[self.input_key].append(data[self.output_key]) + data[self.output_key] = None + + if self.multiturn: + assert not self.output_key or not data[self.output_key], "You should put the whole trajactory into data[input_key] and do not set output_key" + input_key = self.input_key + apply_chat_template = self.apply_chat_template + response_ranges = [] + for idx, message in enumerate(data[input_key]): + if message['role'] == 'assistant': + prompt = apply_chat_template(data[input_key][: idx], tokenize=False, add_generation_prompt=True) + response = apply_chat_template(data[input_key][: idx + 1], tokenize=False)[len(prompt):] + + start_idx = self.tokenizer( + prompt, + max_length=self.max_length, + padding=False, + truncation=True, + return_tensors="pt", + add_special_tokens=False, + )["attention_mask"].int().sum().item() + + end_idx = start_idx + self.tokenizer( + response, + max_length=self.max_length, + padding=False, + truncation=True, + return_tensors="pt", + add_special_tokens=False, + )["attention_mask"].int().sum().item() - 1 + response_ranges.append((start_idx, end_idx)) # left close right open + + prompt, response = preprocess_data( + data, + None if self.pretrain_mode else self.input_template, + self.input_key, + self.output_key, + apply_chat_template=None if self.pretrain_mode else self.apply_chat_template, + multiturn=self.multiturn, + ) + + if not self.pretrain_mode: + prompt_token = self.tokenizer( + prompt, + max_length=self.max_length, + padding=False, + truncation=True, + return_tensors="pt", + add_special_tokens=False, + ) + prompt_ids_len = prompt_token["attention_mask"].int().sum().item() + + # filter the sample whose length is greater than max_length (2 for answer length) + if not prompt or not response or prompt_ids_len >= self.max_length - 2: + prompt = None + else: + prompt_ids_len = 0 + + return {"prompt": prompt, "response": response, "prompt_ids_len": prompt_ids_len, "response_ranges": response_ranges if self.multiturn else None} + + def __len__(self): + length = len(self.prompts) + return length + + def __getitem__(self, idx): + prompt_ids_len = self.prompt_ids_lens[idx] + prompt = self.prompts[idx] + response = self.responses[idx] + + if not self.pretrain_mode: + text = (prompt + response).rstrip("\n") + if not text.endswith(self.tokenizer.eos_token): + text += " " + self.tokenizer.eos_token + else: + text = prompt + + input_token = self.tokenizer( + text, + 
max_length=self.max_length, + padding=False, + truncation=True, + return_tensors="pt", + add_special_tokens=False, + ) + + if not self.pretrain_mode: + # to avoid EOS_token truncation + input_token["input_ids"][0][-1] = self.tokenizer.eos_token_id + input_token["attention_mask"][0][-1] = True + info = {"input": prompt, "output": response, "input_length": input_token["attention_mask"].int().sum().item(), "response_ranges": self.response_ranges[idx] if self.multiturn else None} + + return prompt_ids_len, input_token["input_ids"], input_token["attention_mask"], info + + def collate_fn(self, item_list): + prompt_ids_lens = [] + input_ids = [] + attention_masks = [] + infos = {"input": [], "output": []} + + for prompt_ids_len, input_id, attention_mask, info in item_list: + prompt_ids_lens.append(prompt_ids_len) + input_ids.append(input_id) + attention_masks.append(attention_mask) + infos["input"].append(info["input"]) + infos["output"].append(info["output"]) + + input_ids = zero_pad_sequences(input_ids, "right", self.tokenizer.pad_token_id) + attention_masks = zero_pad_sequences(attention_masks, "right") + return prompt_ids_lens, input_ids, attention_masks, infos + + def packing_collate_fn(self, item_list): + packed_input_ids = [] + packed_attention_masks = [] + prompt_ids_lens = [] + infos = {"input_length": [], "response_ranges": [] if self.multiturn else None} + index = 1 + for prompt_ids_len, input_id, attention_mask, info in item_list: + packed_input_ids.append(input_id.flatten()) + packed_attention_masks.append(torch.full_like(input_id.flatten(), index)) + prompt_ids_lens.append(prompt_ids_len) + infos["input_length"].append(info["input_length"]) + if self.multiturn: + if len(infos["response_ranges"]) >= 1: + for i in range(len(info["response_ranges"])): + info["response_ranges"][i][0] += infos["response_ranges"][-1][-1][1] # end_index of the last response of the last item + info["response_ranges"][i][1] += infos["response_ranges"][-1][-1][1] + infos["response_ranges"].append(info["response_ranges"]) + index += 1 + + packed_input_ids = torch.cat(packed_input_ids, dim=0).unsqueeze(0) + packed_attention_masks = torch.cat(packed_attention_masks, dim=0).unsqueeze(0) + + if ( + self.multiple_of > 1 and packed_input_ids.numel() % self.multiple_of != 0 + ): # not divisible by multiple_of; here we align for grouping + padding_len = self.multiple_of - (packed_input_ids.numel() % self.multiple_of) + packed_input_ids = F.pad(packed_input_ids, (0, padding_len), value=self.tokenizer.pad_token_id) + packed_attention_masks = F.pad(packed_attention_masks, (0, padding_len), value=0) + + return prompt_ids_lens, packed_input_ids, packed_attention_masks, infos diff --git a/PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/openrlhf/datasets/unpaired_preference_dataset.py b/PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/openrlhf/datasets/unpaired_preference_dataset.py new file mode 100644 index 0000000000..47e9fd3021 --- /dev/null +++ b/PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/openrlhf/datasets/unpaired_preference_dataset.py @@ -0,0 +1,143 @@ +from typing import Callable + +import torch +from torch.utils.data import Dataset + +from .utils import zero_pad_sequences + + +def preprocess_data( + data, input_template=None, input_key=None, output_key=None, label_key=None, apply_chat_template=None +): + """ + Preprocess data from raw dataset to prompt, response, label + + Args: + data: raw data from dataset + """ + label = data[label_key] + + if apply_chat_template: + if output_key: + prompt = 
apply_chat_template(data[input_key], tokenize=False, add_generation_prompt=True) + response = apply_chat_template(data[input_key] + data[output_key], tokenize=False)[len(prompt) :] + else: + prompt = apply_chat_template(data[input_key][:-1], tokenize=False, add_generation_prompt=True) + response = apply_chat_template(data[input_key], tokenize=False)[len(prompt) :] + else: + prompt = data[input_key] + response = data[output_key] + if input_template: + prompt = input_template.format(prompt) + return prompt, response, label + + +class UnpairedPreferenceDataset(Dataset): + """ + Unpaired preference dataset for algorithm, like KTO + + Args: + dataset: raw dataset + self.tokenizer: self.tokenizer for model + self.max_length: max length of input + """ + + def __init__( + self, dataset, tokenizer: Callable, max_length: int, strategy, input_template=None, num_processors=8 + ) -> None: + super().__init__() + self.tokenizer = tokenizer + self.strategy = strategy + self.max_length = max_length + + # chat_template + self.input_template = input_template + self.input_key = getattr(self.strategy.args, "input_key", None) + self.output_key = getattr(self.strategy.args, "output_key", None) + self.label_key = getattr(self.strategy.args, "label_key", None) + self.apply_chat_template = getattr(self.strategy.args, "apply_chat_template", False) + + if self.apply_chat_template: + self.apply_chat_template = self.tokenizer.apply_chat_template + tokenizer_chat_template = getattr(self.strategy.args, "tokenizer_chat_template", None) + if tokenizer_chat_template: + self.tokenizer.chat_template = tokenizer_chat_template + + # Parallel loading datasets + processed_dataset = dataset.map( + self.process_data, remove_columns=dataset.column_names, num_proc=num_processors + ) + + # Filter out None values if necessary + processed_dataset = processed_dataset.filter(lambda x: x["prompt"] is not None) + + # Store the processed data in class attributes + self.prompts = processed_dataset["prompt"] + self.responses = processed_dataset["response"] + self.labels = processed_dataset["label"] + self.prompt_ids_lens = processed_dataset["prompt_ids_len"] + + def process_data(self, data): + prompt, response, label = preprocess_data( + data, self.input_template, self.input_key, self.output_key, self.label_key, self.apply_chat_template + ) + prompt_token = self.tokenizer( + prompt, + max_length=self.max_length, + padding=False, + truncation=True, + return_tensors="pt", + add_special_tokens=False, + ) + prompt_ids_len = prompt_token["attention_mask"].int().sum().item() + + # filter the sample whose length is greater than max_length (2 for answer length) + if prompt_ids_len >= self.max_length - 2: + prompt = None + + return {"prompt": prompt, "response": response, "label": label, "prompt_ids_len": prompt_ids_len} + + def __len__(self): + return len(self.prompts) + + def __getitem__(self, index): + return self.prompts[index], self.responses[index], self.labels[index], self.prompt_ids_lens[index] + + def collate_fn(self, item_list): + def tokenizer(prompt, response): + text = (prompt + response).rstrip("\n") + if not text.endswith(self.tokenizer.eos_token): + text += " " + self.tokenizer.eos_token + inputs = self.tokenizer( + text, + max_length=self.max_length, + padding=False, + truncation=True, + return_tensors="pt", + add_special_tokens=False, + ) + + inputs["input_ids"][0][-1] = self.tokenizer.eos_token_id + inputs["attention_mask"][0][-1] = True + return inputs["input_ids"], inputs["attention_mask"] + + tot_ids, tot_masks, tot_labels, 
prompt_ids_lens = [], [], [], [] + for prompt, response, label, prompt_ids_len in item_list: + input_ids, attention_mask = tokenizer(prompt, response) + tot_ids.append(input_ids) + tot_masks.append(attention_mask) + tot_labels.append(label) + prompt_ids_lens.append(prompt_ids_len) + + # add unmatched y'| x (used to estimate the KL divergence between policy and reference) + for idx in range(len(item_list)): + next_idx = (idx + 1) % len(item_list) + input_ids, attention_mask = tokenizer(item_list[idx][0], item_list[next_idx][1]) + tot_ids.append(input_ids) + tot_masks.append(attention_mask) + tot_labels.append(-1) + prompt_ids_lens.append(item_list[idx][3]) + + input_ids = zero_pad_sequences(tot_ids, side="right", value=self.tokenizer.pad_token_id) + attention_mask = zero_pad_sequences(tot_masks, side="right") + return input_ids, attention_mask, torch.LongTensor(tot_labels), prompt_ids_lens diff --git a/PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/openrlhf/datasets/utils.py b/PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/openrlhf/datasets/utils.py new file mode 100644 index 0000000000..0d11efc585 --- /dev/null +++ b/PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/openrlhf/datasets/utils.py @@ -0,0 +1,17 @@ +import torch +import torch.nn.functional as F + + +def zero_pad_sequences(sequences, side: str = "left", value=0): + assert side in ("left", "right") + max_len = max(seq.size(-1) for seq in sequences) + padded_sequences = [] + for seq in sequences: + pad_len = max_len - seq.size(-1) + padding = (pad_len, 0) if side == "left" else (0, pad_len) + padded_sequences.append(F.pad(seq, padding, value=value)) + return torch.stack(padded_sequences, dim=0) + + +def exist_and_not_none(d, key): + return key in d and not d[key] is None diff --git a/PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/openrlhf/models/__init__.py b/PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/openrlhf/models/__init__.py new file mode 100644 index 0000000000..d26fbc7edd --- /dev/null +++ b/PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/openrlhf/models/__init__.py @@ -0,0 +1,29 @@ +from .actor import Actor +from .loss import ( + DPOLoss, + GPTLMLoss, + KDLoss, + KTOLoss, + LogExpLoss, + PairWiseLoss, + PolicyLoss, + PRMLoss, + ValueLoss, + VanillaKTOLoss, +) +from .model import get_llm_for_sequence_regression + +__all__ = [ + "Actor", + "DPOLoss", + "GPTLMLoss", + "KDLoss", + "KTOLoss", + "LogExpLoss", + "PairWiseLoss", + "PolicyLoss", + "PRMLoss", + "ValueLoss", + "VanillaKTOLoss", + "get_llm_for_sequence_regression", +] diff --git a/PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/openrlhf/models/actor.py b/PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/openrlhf/models/actor.py new file mode 100644 index 0000000000..68009603c6 --- /dev/null +++ b/PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/openrlhf/models/actor.py @@ -0,0 +1,242 @@ +from typing import Optional, Tuple, Union + +import torch +import torch.distributed as dist +import torch.nn as nn +from peft import LoraConfig, TaskType, get_peft_model +from peft.tuners.lora import LoraLayer +from transformers import AutoModelForCausalLM, BitsAndBytesConfig +from transformers.integrations.deepspeed import HfDeepSpeedConfig + +from .ring_attn_utils import convert_ring_attn_params +from .utils import log_probs_from_logits, reset_position_ids + + +class Actor(nn.Module): + """ + Base class for Actor models in reinforcement learning. 
+ + This class serves as a foundation for implementing various actor models, which are responsible for selecting actions based on the policy learned from the environment. + + Args: + pretrain_or_model (nn.Module): A pretrained model or a new model instance to be used as the actor. + use_flash_attention_2 (bool, optional): Whether to utilize Flash Attention 2.0 for improved performance. Defaults to False. + bf16 (bool, optional): Enable bfloat16 precision for model computations. Defaults to True. + load_in_4bit (bool, optional): Load the model in 4-bit precision. Defaults to False. + lora_rank (int, optional): Rank for LoRA adaptation. Defaults to 0. + lora_alpha (int, optional): Alpha parameter for LoRA. Defaults to 16. + lora_dropout (float, optional): Dropout rate for LoRA layers. Defaults to 0. + target_modules (list, optional): List of target modules for applying LoRA. Defaults to None. + ds_config (dict, optional): Configuration for DeepSpeed, enabling model partitioning across multiple GPUs. Defaults to None. + device_map (dict, optional): Device mapping for loading the model onto specific devices. Defaults to None. + packing_samples (bool, optional): Whether to pack samples during training. Defaults to False. + """ + + def __init__( + self, + pretrain_or_model, + use_flash_attention_2=False, + bf16=True, + load_in_4bit=False, + lora_rank=0, + lora_alpha=16, + lora_dropout=0, + target_modules=None, + ds_config=None, + device_map=None, + packing_samples=False, + **kwargs, + ) -> None: + super().__init__() + + if isinstance(pretrain_or_model, str): + attn_implementation = "flash_attention_2" if use_flash_attention_2 else "eager" + + # Note: dschf is defined in function scope to avoid global effects + # https://huggingface.co/docs/transformers/deepspeed#non-trainer-deepspeed-integration + if ds_config is not None and ds_config["zero_optimization"]["stage"] == 3: + dschf = HfDeepSpeedConfig(ds_config) + else: + dschf = None + + if load_in_4bit: + assert bf16, "we only support bnb_4bit_compute_dtype = bf16" + nf4_config = BitsAndBytesConfig( + load_in_4bit=True, + bnb_4bit_quant_type="nf4", + bnb_4bit_use_double_quant=True, + bnb_4bit_compute_dtype=torch.bfloat16, + ) + else: + nf4_config = None + + self.model = AutoModelForCausalLM.from_pretrained( + pretrain_or_model, + trust_remote_code=True, + attn_implementation=attn_implementation, + quantization_config=nf4_config, + torch_dtype=torch.bfloat16 if bf16 else "auto", + device_map=device_map, + ) + + # LoRA + if lora_rank > 0: + # https://github.com/huggingface/peft/issues/137 + self.model.enable_input_require_grads() + lora_config = LoraConfig( + task_type=TaskType.CAUSAL_LM, + r=lora_rank, + lora_alpha=lora_alpha, + target_modules=target_modules, + lora_dropout=lora_dropout, + bias="none", + ) + self.model = get_peft_model(self.model, lora_config) + + if load_in_4bit: + for name, module in self.model.named_modules(): + if isinstance(module, LoraLayer): + module = module.to(torch.bfloat16) + if "norm" in name: + module = module.to(torch.float32) + if "lm_head" in name or "embed_tokens" in name: + if hasattr(module, "weight"): + module = module.to(torch.bfloat16) + + # MoE - balancing loss + model_config = self.model.config.to_dict() + if "output_router_logits" in model_config: + print("[MoE] set output_router_logits as True") + self.model.config.output_router_logits = True + + # https://github.com/huggingface/transformers/issues/26877 + # Use `model.generate(use_cache=True)` instead.` + self.model.config.use_cache = False + + # packing 
samples using Flash Attention 2 + self.packing_samples = packing_samples + else: + self.model = pretrain_or_model + + @torch.no_grad() + def generate(self, input_ids: torch.Tensor, **kwargs) -> Union[ + Tuple[torch.LongTensor, torch.LongTensor], + Tuple[torch.LongTensor, torch.LongTensor, torch.BoolTensor], + ]: + generate_args = { + "input_ids": input_ids, + "top_k": kwargs.get("top_k", None), + "top_p": kwargs.get("top_p", None), + "do_sample": kwargs.get("do_sample", True), + "early_stopping": True, + "temperature": kwargs.get("temperature", 1), + "use_cache": True, + "num_beams": kwargs.get("num_beams", 1), + "attention_mask": kwargs.get("attention_mask"), + "eos_token_id": kwargs.get("eos_token_id"), + "pad_token_id": kwargs.get("pad_token_id"), + "min_new_tokens": kwargs.get("min_new_tokens", 1), + } + + if kwargs.get("max_new_tokens", None): + generate_args["max_new_tokens"] = kwargs.get("max_new_tokens") + if kwargs.get("max_length", None): + generate_args["max_length"] = kwargs.get("max_length") + + # Call generate + sequences = self.model.generate(**generate_args) + + # Prepare mask tensor + eos_token_id = generate_args["eos_token_id"] + pad_token_id = generate_args["pad_token_id"] + + return self.process_sequences(sequences, input_ids.size(1), eos_token_id, pad_token_id) + + def process_sequences(self, sequences: torch.Tensor, input_len, eos_token_id, pad_token_id): + attention_mask = (sequences.ne(eos_token_id) & sequences.ne(pad_token_id)).to(dtype=torch.long) + seq_length = attention_mask.size(1) + + # The following code is equivalent to: + # + # for i in range(attention_mask.size(0)): + # for t in reversed(range(seq_length)): + # if attention_mask[i][t] > 0.5: + # attention_mask[i][min(t + 1, seq_length - 1)] = True + # sequences[i][min(t + 1, seq_length - 1)] = eos_token_id + # break + # + eos_indices = seq_length - attention_mask.long().fliplr().argmax(dim=1, keepdim=True).clamp(min=1) + sequences.scatter_(dim=1, index=eos_indices, value=eos_token_id) + + # For Llama3 and Qwen2 models, there are some eos_tokens in the middle of the prompt. 
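+        # Illustrative walk-through of this method (hypothetical values, pad_token_id=0, eos_token_id=2):
+        #   sequences           = [[0, 0, 5, 6, 7, 2, 0, 0]]
+        #   eos_indices         = [[5]]  (last generated position, forced to EOS above)
+        #   first_token_indices = [[2]]  (first non-pad position, computed below)
+        #   attention_mask      = [[0, 0, 1, 1, 1, 1, 0, 0]]  (True from the first token through the final EOS)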
+ first_token_indices = attention_mask.long().argmax(dim=1, keepdim=True) + mask = torch.arange(seq_length).unsqueeze(0).expand(sequences.size(0), -1).to(device=sequences.device) + attention_mask = (mask >= first_token_indices) & (mask <= eos_indices).to(dtype=torch.long) + + # in RL, state_i (current token) + action_i (next token) -> state_i+1 (next token) + state_seq = sequences[:, input_len - 1 : -1] + action_mask = state_seq.ne(eos_token_id) & state_seq.ne(pad_token_id) + action_mask[:, 0] = 1 + + return sequences, attention_mask, action_mask + + def forward( + self, + sequences: torch.LongTensor, + num_actions: Optional[Union[int, list[int]]] = None, + attention_mask: Optional[torch.Tensor] = None, + return_output=False, + ring_attn_group: Optional[dist.ProcessGroup] = None, + packed_seq_lens: Optional[list[int]] = None, + ) -> torch.Tensor: + """Returns action log probs""" + if not self.packing_samples: + # https://github.com/OpenRLHF/OpenRLHF/issues/217 + position_ids = attention_mask.long().cumsum(-1) - 1 + position_ids.masked_fill_(attention_mask == 0, 1) + else: + # convert attention_mask to position_ids + if ring_attn_group is not None: + sequences, attention_mask, position_ids = convert_ring_attn_params( + sequences, attention_mask, packed_seq_lens, ring_attn_group + ) + else: + position_ids = reset_position_ids(attention_mask) + # explicitly ignore attention_mask for packing_samples + attention_mask = None + + output = self.model(sequences, attention_mask=attention_mask, position_ids=position_ids) + # https://github.com/OpenRLHF/OpenRLHF/pull/634 + output["logits"] = output["logits"].to(torch.float32) + + if num_actions is None: + assert return_output + return output + + log_probs = log_probs_from_logits(output["logits"][:, :-1, :], sequences[:, 1:]) + + if not self.packing_samples: + action_log_probs = log_probs[:, -num_actions:] + else: + assert isinstance(num_actions, list) and len(num_actions) == len(packed_seq_lens) + action_log_probs = [] + offset = 0 + for num_action, seq_len in zip(num_actions, packed_seq_lens): + start, end = max(0, offset + seq_len - num_action - 1), offset + seq_len - 1 + action_log_probs.append(log_probs[:, start:end]) + offset += seq_len + action_log_probs = torch.cat(action_log_probs, dim=1) + + if return_output: + return (action_log_probs, output) + else: + return action_log_probs + + def gradient_checkpointing_enable(self, gradient_checkpointing_kwargs={"use_reentrant": False}): + self.model.gradient_checkpointing_enable(gradient_checkpointing_kwargs=gradient_checkpointing_kwargs) + + def gradient_checkpointing_disable(self): + self.model.gradient_checkpointing_disable() + + def print_trainable_parameters(self): + self.model.print_trainable_parameters() diff --git a/PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/openrlhf/models/loss.py b/PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/openrlhf/models/loss.py new file mode 100644 index 0000000000..df1215af56 --- /dev/null +++ b/PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/openrlhf/models/loss.py @@ -0,0 +1,326 @@ +from typing import Optional, Tuple + +import torch +import torch.distributed as dist +import torch.nn as nn +import torch.nn.functional as F + +from .utils import masked_mean + + +class GPTLMLoss(nn.Module): + """ + GPT Language Model Loss + """ + + def __init__(self, ring_attn_group=None): + super().__init__() + self.IGNORE_INDEX = -100 + self.loss = nn.CrossEntropyLoss(ignore_index=self.IGNORE_INDEX) + + self.ring_attn_group = ring_attn_group + if self.ring_attn_group: + 
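            # With ring attention, each rank only holds a shard of the packed sequence; record this rank's index +
            # and the group size so forward() can slice the labels to the local shard and average the loss across ranks. +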
self.ring_attn_rank = dist.get_rank(self.ring_attn_group) + self.ring_attn_world_size = dist.get_world_size(self.ring_attn_group) + + def forward(self, logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor: + # RingAttention + if self.ring_attn_group is not None: + total_seq_len = labels.size(-1) + seq_len_per_process = total_seq_len // self.ring_attn_world_size + start_idx = self.ring_attn_rank * seq_len_per_process + end_idx = min(start_idx + seq_len_per_process, total_seq_len) + labels = labels[..., start_idx:end_idx] + + shift_logits = logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + + # if labels are all IGNORE_INDEX, then nn.CrossEntropyLoss will be nan + if torch.all(shift_labels == self.IGNORE_INDEX): + # Use mean of logits multiplied by 0 to maintain gradient flow + loss = shift_logits.mean() * 0 + else: + loss = self.loss(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)) + + dist.all_reduce(loss, op=dist.ReduceOp.SUM, group=self.ring_attn_group) + loss = loss / self.ring_attn_world_size + else: + shift_logits = logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + + loss = self.loss(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)) + + return loss + + +class PolicyLoss(nn.Module): + """ + Policy Loss for PPO + """ + + def __init__(self, clip_eps: float = 0.2) -> None: + super().__init__() + self.clip_eps = clip_eps + + def forward( + self, + log_probs: torch.Tensor, + old_log_probs: torch.Tensor, + advantages: torch.Tensor, + action_mask: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + ratio = (log_probs - old_log_probs).exp() + surr1 = ratio * advantages + surr2 = ratio.clamp(1 - self.clip_eps, 1 + self.clip_eps) * advantages + loss = -torch.min(surr1, surr2) + loss = masked_mean(loss, action_mask, dim=-1).mean() + return loss + + +class ValueLoss(nn.Module): + """ + Value Loss for PPO + """ + + def __init__(self, clip_eps: float = None) -> None: + super().__init__() + self.clip_eps = clip_eps + + def forward( + self, + values: torch.Tensor, + old_values: torch.Tensor, + returns: torch.Tensor, + action_mask: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + if self.clip_eps is not None: + values_clipped = old_values + (values - old_values).clamp(-self.clip_eps, self.clip_eps) + surr1 = (values_clipped - returns) ** 2 + surr2 = (values - returns) ** 2 + loss = torch.max(surr1, surr2) + else: + loss = (values - returns) ** 2 + + loss = masked_mean(loss, action_mask, dim=-1).mean() + return 0.5 * loss + + +class PairWiseLoss(nn.Module): + """ + Pairwise Loss for Reward Model + """ + + def forward( + self, chosen_reward: torch.Tensor, reject_reward: torch.Tensor, margin: torch.Tensor = None + ) -> torch.Tensor: + if margin is not None: + loss = -F.logsigmoid(chosen_reward - reject_reward - margin) + else: + loss = -F.logsigmoid(chosen_reward - reject_reward) + return loss.mean() + + +class LogExpLoss(nn.Module): + """ + Pairwise Loss for Reward Model + Details: https://arxiv.org/abs/2204.05862 + """ + + def forward( + self, chosen_reward: torch.Tensor, reject_reward: torch.Tensor, margin: torch.Tensor = None + ) -> torch.Tensor: + loss = torch.log(1 + torch.exp(reject_reward - chosen_reward)).mean() + return loss + + +class DPOLoss(nn.Module): + """ + DPO Loss + """ + + def __init__(self, beta: float, label_smoothing: float = 0.0, ipo: bool = False) -> None: + super().__init__() + self.beta = beta + self.label_smoothing = label_smoothing + self.ipo = ipo + + def forward( 
+ self, + policy_chosen_logps: torch.Tensor, + policy_rejected_logps: torch.Tensor, + reference_chosen_logps: torch.Tensor, + reference_rejected_logps: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + pi_logratios = policy_chosen_logps - policy_rejected_logps + ref_logratios = reference_chosen_logps - reference_rejected_logps + logits = pi_logratios - ref_logratios + + if self.ipo: + losses = (logits - 1 / (2 * self.beta)) ** 2 # Eq. 17 of https://arxiv.org/pdf/2310.12036v2.pdf + else: + # Eq. 3 https://ericmitchell.ai/cdpo.pdf; label_smoothing=0 gives original DPO (Eq. 7 of https://arxiv.org/pdf/2305.18290.pdf) + losses = ( + -F.logsigmoid(self.beta * logits) * (1 - self.label_smoothing) + - F.logsigmoid(-self.beta * logits) * self.label_smoothing + ) + + loss = losses.mean() + chosen_rewards = self.beta * (policy_chosen_logps - reference_chosen_logps).detach() + rejected_rewards = self.beta * (policy_rejected_logps - reference_rejected_logps).detach() + + return loss, chosen_rewards, rejected_rewards + + +# Adapted from https://github.com/ContextualAI/HALOs/blob/ca9b7e3eeea220c0944ad8095d641da33f907a7e/trainers.py#L742 +class VanillaKTOLoss(nn.Module): + """ + KTO loss for even sampling + """ + + def __init__(self, beta: float) -> None: + super().__init__() + self.beta = beta + + def forward( + self, + policy_chosen_logps: torch.FloatTensor, + policy_rejected_logps: torch.FloatTensor, + reference_chosen_logps: torch.FloatTensor, + reference_rejected_logps: torch.FloatTensor, + ) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]: + chosen_KL = (policy_chosen_logps - reference_chosen_logps).mean().clamp(min=0) + rejected_KL = (policy_rejected_logps - reference_rejected_logps).mean().clamp(min=0) + + chosen_logratios = policy_chosen_logps - reference_chosen_logps + rejected_logratios = policy_rejected_logps - reference_rejected_logps + + losses = torch.cat( + ( + 1 - F.sigmoid(self.beta * (chosen_logratios - rejected_KL)), + 1 - F.sigmoid(self.beta * (chosen_KL - rejected_logratios)), + ), + 0, + ).mean() + + chosen_rewards = self.beta * (policy_chosen_logps - reference_chosen_logps).detach() + rejected_rewards = self.beta * (policy_rejected_logps - reference_rejected_logps).detach() + return losses, chosen_rewards, rejected_rewards + + +# Adapted from https://github.com/ContextualAI/HALOs/blob/ca9b7e3eeea220c0944ad8095d641da33f907a7e/trainers.py#L770 +class KTOLoss(nn.Module): + """ + KTO loss for uneven sampling + """ + + def __init__( + self, beta: float, desirable_weight: float, undesirable_weight: float, world_size: int, device: torch.device + ) -> None: + super().__init__() + self.beta = beta + self.world_size = world_size + self.device = device + self.desirable_weight = desirable_weight + self.undesirable_weight = undesirable_weight + + def forward( + self, + policy_chosen_logps: torch.FloatTensor, + policy_rejected_logps: torch.FloatTensor, + policy_KL_logps: torch.FloatTensor, + reference_chosen_logps: torch.FloatTensor, + reference_rejected_logps: torch.FloatTensor, + reference_KL_logps: torch.FloatTensor, + ) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]: + KL = (policy_KL_logps - reference_KL_logps).mean().detach() + # all_reduce sums up the KL estimates across all devices (gradient will also be scaled by world size) + dist.all_reduce(KL, op=dist.ReduceOp.SUM) + # take average (will also scale gradients appropriately) + KL = (KL / self.world_size).clamp(min=0) + + if policy_chosen_logps.shape[0] != 0: + chosen_logratios 
= policy_chosen_logps - reference_chosen_logps + chosen_losses = 1 - F.sigmoid(self.beta * (chosen_logratios - KL)) + chosen_rewards = self.beta * chosen_logratios.detach() + else: + # important to cast to policy_dtype; otherwise error will occur during all_gather + chosen_losses = torch.Tensor([]).to(policy_rejected_logps.dtype).to(self.device) + chosen_rewards = torch.Tensor([]).to(policy_rejected_logps.dtype).to(self.device) + + if policy_rejected_logps.shape[0] != 0: + rejected_logratios = policy_rejected_logps - reference_rejected_logps + rejected_losses = 1 - F.sigmoid(self.beta * (KL - rejected_logratios)) + rejected_rewards = self.beta * rejected_logratios.detach() + else: + # important to cast to policy_dtype; otherwise error will occur during all_gather + rejected_losses = torch.Tensor([]).to(policy_chosen_logps.dtype).to(self.device) + rejected_rewards = torch.Tensor([]).to(policy_chosen_logps.dtype).to(self.device) + + losses = torch.cat( + (self.desirable_weight * chosen_losses, self.undesirable_weight * rejected_losses), 0 + ).mean() + return losses, chosen_rewards, rejected_rewards, KL + + +# Adapted from https://github.com/microsoft/LMOps/blob/main/minillm/finetune.py#L166 +class KDLoss(nn.Module): + """ + Language Model Knowledge Distillation Loss + """ + + def __init__(self): + super().__init__() + self.IGNORE_INDEX = -100 + + def forward(self, logits: torch.Tensor, teacher_logits: torch.Tensor, label: torch.Tensor) -> torch.Tensor: + teacher_probs = F.softmax(teacher_logits, dim=-1, dtype=torch.float32) + inf_mask = torch.isinf(logits) + logprobs = F.log_softmax(logits, dim=-1, dtype=torch.float32) + prod_probs = torch.masked_fill(teacher_probs * logprobs, inf_mask, 0) + x = torch.sum(prod_probs, dim=-1).view(-1) + mask = (label != self.IGNORE_INDEX).int() + distil_loss = -torch.sum(x * mask.view(-1), dim=0) / torch.sum(mask.view(-1), dim=0) + + return distil_loss + + +class PRMLoss(nn.Module): + """ + Process Reward Model Loss + """ + + def __init__(self, placeholder_token_id: int, reward_token_ids: Optional[list[int]] = None): + super().__init__() + self.IGNORE_INDEX = -100 + self.loss = nn.CrossEntropyLoss(ignore_index=self.IGNORE_INDEX) + self.placeholder_token_id = placeholder_token_id + self.reward_token_ids = reward_token_ids + + def forward(self, inputs: torch.Tensor, logits: torch.Tensor, labels: torch.Tensor, *, return_acc: bool = False): + placeholder_mask = inputs == self.placeholder_token_id + logits = logits[placeholder_mask] + labels = labels[placeholder_mask] + + if labels.dtype == torch.float: + # soft label + assert len(self.reward_token_ids) == 2, "reward_token_ids should have 2 tokens for soft labels" + logits = logits[..., self.reward_token_ids] + positive_labels = labels.to(logits.dtype) + negative_labels = 1 - positive_labels + negative_labels[positive_labels != -100] = 1 - positive_labels[positive_labels != -100] + labels = torch.stack([positive_labels, negative_labels], dim=-1) + elif self.reward_token_ids is not None: + # hard label with reward_token_ids set. (otherwise the whole vocab will be trained together.) + logits = logits[..., self.reward_token_ids] + # this is slow.... 
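+            # Added note: the loop below remaps vocabulary token ids in `labels` to
+            # class indices matching the columns kept by the gather above. For a
+            # hypothetical reward_token_ids = [32001, 32002], a hard label of 32002
+            # becomes class index 1, while ignore labels (-100) are left untouched
+            # because they never equal a reward token id.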
+ for i, token in enumerate(self.reward_token_ids): + labels = torch.where(labels == token, i, labels) + + loss = self.loss(logits, labels) + if not return_acc: + return loss + + if labels.dtype == logits.dtype: + labels = labels.argmax(dim=-1) + acc = (logits.argmax(dim=-1) == labels).float().mean() + return loss, acc diff --git a/PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/openrlhf/models/model.py b/PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/openrlhf/models/model.py new file mode 100644 index 0000000000..3d2102dc94 --- /dev/null +++ b/PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/openrlhf/models/model.py @@ -0,0 +1,304 @@ +from typing import Optional, Union + +import deepspeed +import torch +import torch.nn as nn +from flash_attn.utils.distributed import all_gather +from peft import LoraConfig, get_peft_model +from peft.tuners.lora import LoraLayer +from transformers import AutoConfig, AutoModel, BitsAndBytesConfig +from transformers.integrations.deepspeed import HfDeepSpeedConfig + +from openrlhf.utils.logging_utils import init_logger + +from .ring_attn_utils import convert_ring_attn_params +from .utils import reset_position_ids + +logger = init_logger(__name__) + + +# Construct transformer with a value head for sequence classification. +# https://github.com/huggingface/transformers/blob/405b56269812056d9593869e22b7b264d806cb1e/src/transformers/models/llama/modeling_llama.py#L1254 +def get_llm_for_sequence_regression( + model_name_or_path: str, + model_type: str, + *, + bf16=True, + load_in_4bit=False, + lora_rank=0, + lora_alpha=16, + target_modules=None, + lora_dropout=0, + normalize_reward=False, + use_flash_attention_2=False, + ds_config: dict = None, + init_value_head: bool = False, + value_head_prefix="score", + device_map=None, + packing_samples=False, + **kwargs, +) -> nn.Module: + """Retrieve a transformer model with a sequence regression head on top. + + This function loads a pretrained transformer model and attaches a linear layer for sequence regression. + + Args: + model_name_or_path (str): Path to the pretrained model. + model_type (str): Type of the model, either "reward" or "critic". + bf16 (bool, optional): Enable bfloat16 precision. Defaults to True. + load_in_4bit (bool, optional): Load the model in 4-bit precision. Defaults to False. + lora_rank (int, optional): Rank for LoRA adaptation. Defaults to 0. + lora_alpha (int, optional): Alpha parameter for LoRA. Defaults to 16. + target_modules (list, optional): List of target modules for LoRA. Defaults to None. + lora_dropout (float, optional): Dropout rate for LoRA layers. Defaults to 0. + normalize_reward (bool, optional): Normalize reward values. Defaults to False. + use_flash_attention_2 (bool, optional): Use Flash Attention 2.0. Defaults to False. + ds_config (dict, optional): Deepspeed configuration for model partitioning across multiple GPUs when ZeRO-3 is enabled. Defaults to None. + init_value_head (bool, optional): Initialize the value head. Defaults to False. + value_head_prefix (str, optional): Prefix for the value head. Defaults to "score". + device_map (dict, optional): Map of devices for model loading. Defaults to None. + packing_samples (bool, optional): Whether to pack samples during training. Defaults to False. + + Returns: + nn.Module: A pretrained transformer model with a sequence regression head. + """ + assert ( + model_type == "critic" or model_type == "reward" + ), f"invalid model_type: {model_type}, should be critic or reward." 
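+    # Added illustrative sketch (not part of the upstream code); the model path below
+    # is a placeholder:
+    #   rm = get_llm_for_sequence_regression(
+    #       "path/to/pretrained-llm", "reward",
+    #       bf16=True, use_flash_attention_2=True, init_value_head=True,
+    #   )
+    # Passing model_type="critic" builds the same backbone but returns per-token
+    # values instead of a single sequence-level reward (see _get_critic_model below).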
+ + config = AutoConfig.from_pretrained(model_name_or_path, trust_remote_code=True) + config.normalize_reward = normalize_reward + config._attn_implementation = "flash_attention_2" if use_flash_attention_2 else "eager" + + # Prioritize using the value_head_prefix in the model configuration. + value_head_prefix = getattr(config, "value_head_prefix", value_head_prefix) + logger.info(f"set value_head_prefix to `{value_head_prefix}`") + + base_class = AutoModel._model_mapping[type(config)] + base_pretrained_class = base_class.__base__ + if model_type == "reward": + cls_class = _get_reward_model(base_pretrained_class, base_class, value_head_prefix, packing_samples) + else: + cls_class = _get_critic_model(base_pretrained_class, base_class, value_head_prefix, packing_samples) + + # Note: dschf is defined in function scope to avoid global effects + # https://huggingface.co/docs/transformers/main_classes/deepspeed#nontrainer-deepspeed-integration + if ds_config is not None and ds_config["zero_optimization"]["stage"] == 3: + dschf = HfDeepSpeedConfig(ds_config) + else: + dschf = None + + if load_in_4bit: + assert bf16, "we only support bnb_4bit_compute_dtype = bf16" + nf4_config = BitsAndBytesConfig( + load_in_4bit=True, + bnb_4bit_quant_type="nf4", + bnb_4bit_use_double_quant=True, + bnb_4bit_compute_dtype=torch.bfloat16, + ) + else: + nf4_config = None + + model = cls_class.from_pretrained( + model_name_or_path, + config=config, + trust_remote_code=True, + torch_dtype=torch.bfloat16 if bf16 else "auto", + quantization_config=nf4_config, + device_map=device_map, + **kwargs, + ) + + # LoRA + if lora_rank > 0: + model.enable_input_require_grads() + lora_config = LoraConfig( + r=lora_rank, + lora_alpha=lora_alpha, + target_modules=target_modules, + lora_dropout=lora_dropout, + bias="none", + ) + model = get_peft_model(model, lora_config) + + if load_in_4bit: + for name, module in model.named_modules(): + if isinstance(module, LoraLayer): + module = module.to(torch.bfloat16) + if "norm" in name: + module = module.to(torch.float32) + if value_head_prefix in name or "embed_tokens" in name: + if hasattr(module, "weight"): + module = module.to(torch.bfloat16) + + # MoE - balancing loss + model_config = model.config.to_dict() + if "output_router_logits" in model_config: + print("[MoE] set output_router_logits as True") + model.config.output_router_logits = True + + # https://github.com/huggingface/transformers/issues/26877 + model.config.use_cache = False + + # NOTE: For reward model training only, initialize value_head manually + # because deepspeed.zero.Init() will not initialize them. + # TODO: Find a better way to clarify reward model training.
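+    # Added note: under ZeRO-3 the value head weight is partitioned across ranks, so
+    # it is gathered via deepspeed.zero.GatheredParameters and written on rank 0 only;
+    # otherwise it is initialized in place. The std of 1 / (hidden_size + 1) keeps the
+    # initial value/reward outputs small.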
+ if init_value_head: + value_head = getattr(model, value_head_prefix) + if dschf is not None: + logger.info("initialize value_head for ZeRO-3 reward model training.") + with deepspeed.zero.GatheredParameters([value_head.weight], modifier_rank=0): + if torch.distributed.get_rank() == 0: + value_head.weight.data.normal_(mean=0.0, std=1 / (config.hidden_size + 1)) + else: + value_head.weight.data.normal_(mean=0.0, std=1 / (config.hidden_size + 1)) + + return model + + +def _get_reward_model(base_pretrained_model, base_llm_model, value_head_prefix="score", packing_samples=False): + class RewardModel(base_pretrained_model): + supports_gradient_checkpointing = True + + def __init__(self, config: AutoConfig): + super().__init__(config) + setattr(self, self.base_model_prefix, base_llm_model(config)) + + self.value_head_prefix = value_head_prefix + setattr(self, value_head_prefix, nn.Linear(config.hidden_size, 1, bias=False)) + + self.packing_samples = packing_samples + + # mean std + self.normalize_reward = config.normalize_reward + self.register_buffer("mean", torch.zeros(1), persistent=False) + self.register_buffer("std", torch.ones(1), persistent=False) + + # load mean/std from config.json + if hasattr(config, "mean"): + self.mean[0] = config.mean + self.std[0] = config.std + + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + return_output=False, + ring_attn_group=None, + packed_seq_lens=None, + ) -> torch.Tensor: + if not self.packing_samples: + # https://github.com/OpenRLHF/OpenRLHF/issues/217 + position_ids = attention_mask.long().cumsum(-1) - 1 + position_ids.masked_fill_(attention_mask == 0, 1) + else: + # convert attention_mask to position_ids + if ring_attn_group is not None: + input_ids, attention_mask, position_ids = convert_ring_attn_params( + input_ids, attention_mask, packed_seq_lens, ring_attn_group + ) + else: + position_ids = reset_position_ids(attention_mask) + # explicitly ignore attention_mask for packing_samples + attention_mask = None + + outputs = getattr(self, self.base_model_prefix)( + input_ids, attention_mask=attention_mask, position_ids=position_ids + ) + last_hidden_states = outputs["last_hidden_state"] + values = getattr(self, self.value_head_prefix)(last_hidden_states).squeeze(-1) + + if self.packing_samples: + if ring_attn_group is not None: + reward = all_gather(values, ring_attn_group).reshape(1, -1) + else: + reward = values + # TODO: convert packed_seq_lens into torch tensor in advance + packed_seq_lens = torch.tensor(packed_seq_lens, device=values.device) + eos_indices = packed_seq_lens.cumsum(dim=0) - 1 + reward = reward.squeeze(0).gather(dim=0, index=eos_indices) + else: + eos_indices = attention_mask.size(1) - 1 - attention_mask.long().fliplr().argmax(dim=1, keepdim=True) + reward = values.gather(dim=1, index=eos_indices).squeeze(1) + + if not self.training and self.normalize_reward: + reward = (reward - self.mean) / self.std + + return (reward, outputs) if return_output else reward + + return RewardModel + + +def _get_critic_model(base_pretrained_model, base_llm_model, value_head_prefix="score", packing_samples=False): + class CriticModel(base_pretrained_model): + supports_gradient_checkpointing = True + + def __init__(self, config: AutoConfig): + super().__init__(config) + setattr(self, self.base_model_prefix, base_llm_model(config)) + + self.value_head_prefix = value_head_prefix + setattr(self, value_head_prefix, nn.Linear(config.hidden_size, 1, bias=False)) + + self.packing_samples = 
packing_samples + + # mean std + self.normalize_reward = config.normalize_reward + self.register_buffer("mean", torch.zeros(1), persistent=False) + self.register_buffer("std", torch.ones(1), persistent=False) + + # load mean/std from config.json + if hasattr(config, "mean"): + self.mean[0] = config.mean + self.std[0] = config.std + + def forward( + self, + input_ids: torch.LongTensor = None, + num_actions: Optional[Union[int, list[int]]] = None, + attention_mask: Optional[torch.Tensor] = None, + return_output=False, + packed_seq_lens=None, + ) -> torch.Tensor: + if not self.packing_samples: + # https://github.com/OpenRLHF/OpenRLHF/issues/217 + position_ids = attention_mask.long().cumsum(-1) - 1 + position_ids.masked_fill_(attention_mask == 0, 1) + else: + # convert attention_mask to position_ids + position_ids = reset_position_ids(attention_mask) + # explicitly ignore attention_mask for packing_samples + attention_mask = None + + outputs = getattr(self, self.base_model_prefix)( + input_ids, attention_mask=attention_mask, position_ids=position_ids + ) + last_hidden_states = outputs["last_hidden_state"] + values = getattr(self, self.value_head_prefix)(last_hidden_states).squeeze(-1)[:, :-1] + + # normalize reward + if self.normalize_reward: + values = (values - self.mean) / self.std + + if num_actions is None: + assert return_output + return outputs + + if not self.packing_samples: + action_values = values[:, -num_actions:] + else: + assert isinstance(num_actions, list) and len(num_actions) == len(packed_seq_lens) + action_values = [] + offset = 0 + for num_action, seq_len in zip(num_actions, packed_seq_lens): + start, end = max(0, offset + seq_len - num_action - 1), offset + seq_len - 1 + action_values.append(values[:, start:end]) + offset += seq_len + action_values = torch.cat(action_values, dim=1) + + if return_output: + return (action_values, outputs) + else: + return action_values + + return CriticModel diff --git a/PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/openrlhf/models/ring_attn_utils.py b/PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/openrlhf/models/ring_attn_utils.py new file mode 100644 index 0000000000..ff76034625 --- /dev/null +++ b/PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/openrlhf/models/ring_attn_utils.py @@ -0,0 +1,74 @@ +import torch +import torch.distributed as dist +import torch.nn.functional as F + + +RING_ATTN_GROUP = None + + +def set_ring_attn_group(group): + global RING_ATTN_GROUP + RING_ATTN_GROUP = group + + +def get_ring_attn_group(): + return RING_ATTN_GROUP + + +def reset_ring_attn_position_ids(start, end, packed_seq_lens): + """ + Calculate position ids for packed_seq_ids[start:end]. + For example, if the packed_seq_lens is [3, 2, 4, 1], start=2, end=8, + the position ids will be [2, 0, 1, 0, 1, 2]. + + Args: + start: the start position + end: the end position + packed_seq_lens: the sequence lengths of packed sequences + """ + position_ids = torch.zeros((1, end - start), dtype=torch.long, device=torch.cuda.current_device()) + offset = 0 + for seqlen in packed_seq_lens: + seq_start = max(offset, start) + seq_end = min(offset + seqlen, end) + if seq_start < seq_end: + position_ids[0, seq_start - start : seq_end - start] = torch.arange(seq_start - offset, seq_end - offset) + + offset += seqlen + if offset >= end: + break + return position_ids + + +def update_ring_attn_params(packed_seq_lens, total_seq_len): + """ + Calculate the cu_seqlens for the current forward pass and pass the value to + the substituted ring_flash_attn. 
+ + Note that total_seq_len may be larger than the sum of packed_seq_lens because of padding. + """ + assert RING_ATTN_GROUP is not None + cu_seqlens = torch.cumsum( + torch.tensor(packed_seq_lens, device=torch.cuda.current_device(), dtype=torch.int32), + dim=-1, + dtype=torch.int32, + ) + cu_seqlens = F.pad(F.pad(cu_seqlens, (1, 0), value=0), (0, 1), value=total_seq_len) + + from ring_flash_attn import update_ring_flash_attn_params + + update_ring_flash_attn_params(cu_seqlens, RING_ATTN_GROUP) + + +def convert_ring_attn_params(sequences, attention_mask, packed_seq_lens, ring_attn_group): + # each rank within the ring group will process sequences[start:end] + ring_attn_rank = dist.get_rank(group=ring_attn_group) + ring_attn_size = dist.get_world_size(group=ring_attn_group) + total_seq_len = sequences.numel() + local_seq_len = total_seq_len // ring_attn_size + start, end = ring_attn_rank * local_seq_len, (ring_attn_rank + 1) * local_seq_len + sequences = sequences[:, start:end] + attention_mask = attention_mask[:, start:end] + position_ids = reset_ring_attn_position_ids(start, end, packed_seq_lens) + update_ring_attn_params(packed_seq_lens, total_seq_len) + return sequences, attention_mask, position_ids diff --git a/PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/openrlhf/models/utils.py b/PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/openrlhf/models/utils.py new file mode 100644 index 0000000000..363f070b9b --- /dev/null +++ b/PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/openrlhf/models/utils.py @@ -0,0 +1,120 @@ +from typing import Optional, Tuple, Union + +import torch +import torch.nn.functional as F + + +def compute_approx_kl( + log_probs: torch.Tensor, + log_probs_base: torch.Tensor, + action_mask: Optional[torch.Tensor] = None, + use_kl_estimator_k3: bool = False, +) -> torch.Tensor: + """ + Compute the approximate KL divergence between two distributions. + Schulman blog: http://joschu.net/blog/kl-approx.html + + Args: + log_probs: Log probabilities of the new distribution. + log_probs_base: Log probabilities of the base distribution. + action_mask: Mask for actions. + """ + + log_ratio = log_probs.float() - log_probs_base.float() + if action_mask is not None: + log_ratio = log_ratio * action_mask + + # The k3 estimator is the non negative kl approximation in + # http://joschu.net/blog/kl-approx.html + # Besides non negative, it is also unbiased and have lower variance. 
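+    # Added note: writing delta = log_probs - log_probs_base, the k3 branch below
+    # computes exp(-delta) - 1 + delta, which is always >= 0 (since e^x >= 1 + x),
+    # while the default branch returns the plain difference delta (the k1 estimator).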
+ if use_kl_estimator_k3: + log_ratio = -log_ratio + log_ratio = log_ratio.exp() - 1 - log_ratio + + return log_ratio + + +def compute_reward( + r: Union[torch.Tensor, float], + kl_coef: float, + kl: Union[torch.Tensor, list[torch.Tensor]], + action_mask: Optional[torch.Tensor] = None, + num_actions: Optional[Union[int, list[int]]] = None, + reward_clip_range: Tuple[float, float] = None, +) -> Union[torch.Tensor, list[torch.Tensor]]: + if kl_coef <= 0.0: + kl_coef = 0.0 + + if reward_clip_range: + r = r.clamp(min=reward_clip_range[0], max=reward_clip_range[1]) + + if action_mask is not None: + kl_reward = -kl_coef * kl + # The following code is equivalent to: + # + # last_reward = torch.zeros_like(kl) + # for i in range(last_reward.size(0)): + # for t in reversed(range(last_reward.size(1))): + # if action_mask[i][t] > 0.5: + # last_reward[i][t] = r[i] + # break + # + eos_indices = action_mask.size(1) - 1 - action_mask.long().fliplr().argmax(dim=1, keepdim=True) + last_reward = torch.zeros_like(kl).scatter_(dim=1, index=eos_indices, src=r.unsqueeze(1).to(kl.dtype)) + + reward = last_reward + kl_reward + else: + # TODO: write a more efficient version + reward = [] + for i, (kl_seg, action_len) in enumerate(zip(kl, num_actions)): + kl_reward = -kl_coef * kl_seg + kl_reward[action_len - 1] += r[i] + reward.append(kl_reward) + + return reward + + +def log_probs_from_logits(logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor: + log_probs = F.log_softmax(logits, dim=-1) + log_probs_labels = log_probs.gather(dim=-1, index=labels.unsqueeze(-1)) + return log_probs_labels.squeeze(-1) + + +def masked_mean(tensor: torch.Tensor, mask: Optional[torch.Tensor], dim: int = None) -> torch.Tensor: + if mask is None: + return tensor.mean(axis=dim) + return (tensor * mask).sum(axis=dim) / mask.sum(axis=dim) + + +def masked_normalize(tensor: torch.Tensor, mask: torch.Tensor, dim: int = 1, eps: float = 1e-8) -> torch.Tensor: + tensor = tensor * mask + mean = masked_mean(tensor, mask, dim=dim) + mean_centered = tensor - mean + var = masked_mean(mean_centered**2, mask, dim=dim) + return mean_centered * var.clamp(min=eps).rsqrt() + + +# Reset positions for packed samples +# For example +# Input: attention_mask = torch.tensor([[1, 1, 1, 2, 2, 2, 3, 3, 0]]) +# Output: position_ids = torch.tensor([[0, 1, 2, 0, 1, 2, 0, 1, 0]]) +def reset_position_ids(attention_mask): + position_ids = torch.zeros_like(attention_mask, dtype=torch.long) + for i in range(attention_mask.size(0)): + mask = attention_mask[i] + seq_num = mask.max().item() + for index in range(1, seq_num + 1): + sample_mask = mask == index + sample_length = sample_mask.sum().item() + position_ids[i, sample_mask] = torch.arange(sample_length, device=mask.device) + return position_ids + + +def unpacking_samples(values: torch.Tensor, packed_seqlens: list[int]): + values = values.squeeze(0) + unpacked_values = [] + offset = 0 + for seqlen in packed_seqlens: + unpacked_values.append(values[offset : offset + seqlen]) + offset += seqlen + return unpacked_values diff --git a/PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/openrlhf/trainer/__init__.py b/PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/openrlhf/trainer/__init__.py new file mode 100644 index 0000000000..a26d247b63 --- /dev/null +++ b/PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/openrlhf/trainer/__init__.py @@ -0,0 +1,17 @@ +from .dpo_trainer import DPOTrainer +from .kd_trainer import KDTrainer +from .kto_trainer import KTOTrainer +from .ppo_trainer import PPOTrainer +from .prm_trainer import 
ProcessRewardModelTrainer +from .rm_trainer import RewardModelTrainer +from .sft_trainer import SFTTrainer + +__all__ = [ + "DPOTrainer", + "KDTrainer", + "KTOTrainer", + "PPOTrainer", + "ProcessRewardModelTrainer", + "RewardModelTrainer", + "SFTTrainer", +] diff --git a/PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/openrlhf/trainer/dpo_trainer.py b/PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/openrlhf/trainer/dpo_trainer.py new file mode 100644 index 0000000000..55088cd559 --- /dev/null +++ b/PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/openrlhf/trainer/dpo_trainer.py @@ -0,0 +1,478 @@ +import os +from abc import ABC + +import torch +from flash_attn.utils.distributed import all_gather +from torch.nn import functional as F +from torch.optim import Optimizer +from tqdm import tqdm + +from openrlhf.models import DPOLoss +from openrlhf.utils.distributed_sampler import DistributedSampler + + +class DPOTrainer(ABC): + """ + Trainer for Direct Preference Optimization (DPO) training. + + Args: + model (torch.nn.Module): The primary model to be trained. + ref_model (torch.nn.Module): The reference model for comparing and guiding preference. + strategy (Strategy): The strategy to use for training. + tokenizer (Tokenizer): The tokenizer for processing input data. + optim (Optimizer): The optimizer for training the model. + train_dataloader (DataLoader): The dataloader for the training dataset. + eval_dataloader (DataLoader): The dataloader for the evaluation dataset. + scheduler (Scheduler): The learning rate scheduler to control learning rate during training. + max_norm (float, defaults to 0.5): Maximum gradient norm for gradient clipping. + beta (float, defaults to 0.01): Coefficient for regularizing the preference loss. + max_epochs (int, defaults to 2): Maximum number of training epochs. + save_hf_ckpt (bool): Whether to save huggingface-format model weight. + disable_ds_ckpt (bool): Whether not to save deepspeed-format model weight. 
(Deepspeed model weight is used for training recovery) + """ + + def __init__( + self, + model, + ref_model, + strategy, + tokenizer, + optim: Optimizer, + train_dataloader, + eval_dataloader, + scheduler, + max_norm=0.5, + beta=0.01, + max_epochs: int = 2, + save_hf_ckpt: bool = False, + disable_ds_ckpt: bool = False, + ) -> None: + super().__init__() + self.strategy = strategy + self.epochs = max_epochs + self.max_norm = max_norm + self.model = model + self.train_dataloader = train_dataloader + self.eval_dataloader = eval_dataloader + self.ref_model = ref_model + self.scheduler = scheduler + self.optimizer = optim + self.tokenizer = tokenizer + self.args = strategy.args + self.save_hf_ckpt = save_hf_ckpt + self.disable_ds_ckpt = disable_ds_ckpt + + self.beta = beta + self.loss_fn = DPOLoss(self.beta, self.args.label_smoothing, self.args.ipo) + + # Mixtral 8*7b + self.aux_loss = self.args.aux_loss_coef > 1e-8 + + # NLL loss + self.nll_loss = self.args.nll_loss_coef > 1e-8 + + # packing samples + self.packing_samples = strategy.args.packing_samples + + # wandb/tensorboard setting + self._wandb = None + self._tensorboard = None + if self.strategy.args.use_wandb and self.strategy.is_rank_0(): + import wandb + + self._wandb = wandb + if not wandb.api.api_key: + wandb.login(key=strategy.args.use_wandb) + wandb.init( + entity=strategy.args.wandb_org, + project=strategy.args.wandb_project, + group=strategy.args.wandb_group, + name=strategy.args.wandb_run_name, + config=strategy.args.__dict__, + reinit=True, + ) + + wandb.define_metric("train/global_step") + wandb.define_metric("train/*", step_metric="train/global_step", step_sync=True) + wandb.define_metric("eval/global_step") + wandb.define_metric("eval/*", step_metric="eval/global_step", step_sync=True) + + # Initialize TensorBoard writer if wandb is not available + if self.strategy.args.use_tensorboard and self._wandb is None and self.strategy.is_rank_0(): + from torch.utils.tensorboard import SummaryWriter + + os.makedirs(self.strategy.args.use_tensorboard, exist_ok=True) + log_dir = os.path.join(self.strategy.args.use_tensorboard, strategy.args.wandb_run_name) + self._tensorboard = SummaryWriter(log_dir=log_dir) + + def fit(self, args, consumed_samples=0, num_update_steps_per_epoch=None): + # get eval and save steps + if args.eval_steps == -1: + args.eval_steps = num_update_steps_per_epoch # Evaluate once per epoch + if args.save_steps == -1: + args.save_steps = float("inf") # do not save ckpt + + # Restore step and start_epoch + step = consumed_samples // args.train_batch_size * self.strategy.accumulated_gradient + 1 + start_epoch = consumed_samples // args.train_batch_size // num_update_steps_per_epoch + consumed_samples = consumed_samples % (num_update_steps_per_epoch * args.train_batch_size) + + epoch_bar = tqdm( + range(start_epoch, self.epochs), + desc="Train epoch", + disable=not self.strategy.is_rank_0(), + ) + acc_sum = 0 + loss_sum = 0 + for epoch in range(start_epoch, self.epochs): + if isinstance(self.train_dataloader.sampler, DistributedSampler): + self.train_dataloader.sampler.set_epoch( + epoch, consumed_samples=0 if epoch > start_epoch else consumed_samples + ) + + step_bar = tqdm( + range(self.train_dataloader.__len__()), + desc="Train step of epoch %d" % epoch, + disable=not self.strategy.is_rank_0(), + ) + + self.model.train() + self.ref_model.eval() + # train + for data in self.train_dataloader: + if not self.packing_samples: + chosen_ids, c_mask, reject_ids, r_mask, prompt_id_lens = data + chosen_ids = 
chosen_ids.squeeze(1).to(torch.cuda.current_device()) + c_mask = c_mask.squeeze(1).to(torch.cuda.current_device()) + reject_ids = reject_ids.squeeze(1).to(torch.cuda.current_device()) + r_mask = r_mask.squeeze(1).to(torch.cuda.current_device()) + + chosen_logps, rejected_logps, aux_loss, nll_loss = self.concatenated_forward( + self.model, chosen_ids, c_mask, reject_ids, r_mask, prompt_id_lens + ) + with torch.no_grad(): + reference_chosen_logps, reference_rejected_logps, _, _ = self.concatenated_forward( + self.ref_model, chosen_ids, c_mask, reject_ids, r_mask, prompt_id_lens + ) + else: + packed_input_ids, packed_attention_masks, packed_seq_lens, prompt_id_lens = data + packed_input_ids, packed_attention_masks = packed_input_ids.to( + torch.cuda.current_device() + ), packed_attention_masks.to(torch.cuda.current_device()) + chosen_logps, rejected_logps, aux_loss, nll_loss = self.packed_samples_forward( + self.model, packed_input_ids, packed_attention_masks, packed_seq_lens, prompt_id_lens + ) + with torch.no_grad(): + reference_chosen_logps, reference_rejected_logps, _, _ = self.packed_samples_forward( + self.ref_model, packed_input_ids, packed_attention_masks, packed_seq_lens, prompt_id_lens + ) + + # loss function + preference_loss, chosen_reward, reject_reward = self.loss_fn( + chosen_logps, rejected_logps, reference_chosen_logps, reference_rejected_logps + ) + # mixtral + if not self.aux_loss: + aux_loss = 0 + # nll loss + if not self.nll_loss: + nll_loss = 0 + + loss = preference_loss + aux_loss * self.args.aux_loss_coef + nll_loss * self.args.nll_loss_coef + self.strategy.backward(loss, self.model, self.optimizer) + self.strategy.optimizer_step(self.optimizer, self.model, self.scheduler) + + acc = (chosen_reward > reject_reward).float().mean().item() + acc_sum += acc + loss_sum += preference_loss.item() + # dpo logs + logs_dict = { + "loss": preference_loss.item(), + "acc": acc, + "chosen_reward": chosen_reward.mean().item(), + "reject_reward": reject_reward.mean().item(), + "lr": self.scheduler.get_last_lr()[0], + } + if self.nll_loss: + logs_dict["nll_loss"] = nll_loss.item() + # step bar + logs_dict = self.strategy.all_reduce(logs_dict) + step_bar.set_postfix(logs_dict) + step_bar.update() + + # logs/checkpoints/evaluation + if step % self.strategy.accumulated_gradient == 0: + logs_dict["loss_mean"] = loss_sum / self.strategy.accumulated_gradient + logs_dict["acc_mean"] = acc_sum / self.strategy.accumulated_gradient + loss_sum = 0 + acc_sum = 0 + global_step = step // self.strategy.accumulated_gradient + client_states = {"consumed_samples": global_step * args.train_batch_size} + self.save_logs_and_checkpoints(args, global_step, step_bar, logs_dict, client_states) + + step += 1 + + epoch_bar.update() + + if self._wandb is not None and self.strategy.is_rank_0(): + self._wandb.finish() + if self._tensorboard is not None and self.strategy.is_rank_0(): + self._tensorboard.close() + + # logs/checkpoints/evaluate + def save_logs_and_checkpoints(self, args, global_step, step_bar, logs_dict={}, client_states={}): + # logs + if global_step % args.logging_steps == 0: + # wandb + if self._wandb is not None and self.strategy.is_rank_0(): + logs = {"train/%s" % k: v for k, v in {**logs_dict, "global_step": global_step}.items()} + self._wandb.log(logs) + # TensorBoard + elif self._tensorboard is not None and self.strategy.is_rank_0(): + for k, v in logs_dict.items(): + self._tensorboard.add_scalar(f"train/{k}", v, global_step) + + # eval + if global_step % args.eval_steps == 0: + # do eval when 
len(dataloader) > 0, avoid zero division in eval. + if len(self.eval_dataloader) > 0: + self.evaluate(self.eval_dataloader, global_step) + + # save ckpt + # TODO: save best model on dev, use loss/perplexity on whole dev dataset as metric + if global_step % args.save_steps == 0: + tag = f"global_step{global_step}" + if not self.disable_ds_ckpt: + self.strategy.save_ckpt( + self.model.model, args.ckpt_path, tag, args.max_ckpt_num, args.max_ckpt_mem, client_states + ) + if self.save_hf_ckpt: + save_path = os.path.join(args.ckpt_path, f"{tag}_hf") + self.strategy.save_model(self.model, self.tokenizer, save_path) + + def evaluate(self, eval_dataloader, steps=0): + self.model.eval() + with torch.no_grad(): + step_bar = tqdm( + range(eval_dataloader.__len__()), + desc="Eval stage of global_step %d" % steps, + disable=not self.strategy.is_rank_0(), + ) + acc_sum = 0 + loss_sum = 0 + times = 0 + for data in eval_dataloader: + if not self.packing_samples: + chosen_ids, c_mask, reject_ids, r_mask, prompt_id_lens = data + chosen_ids = chosen_ids.squeeze(1).to(torch.cuda.current_device()) + c_mask = c_mask.squeeze(1).to(torch.cuda.current_device()) + reject_ids = reject_ids.squeeze(1).to(torch.cuda.current_device()) + r_mask = r_mask.squeeze(1).to(torch.cuda.current_device()) + + chosen_logps, rejected_logps, aux_loss, _ = self.concatenated_forward( + self.model, chosen_ids, c_mask, reject_ids, r_mask, prompt_id_lens + ) + with torch.no_grad(): + reference_chosen_logps, reference_rejected_logps, _, _ = self.concatenated_forward( + self.ref_model, chosen_ids, c_mask, reject_ids, r_mask, prompt_id_lens + ) + else: + packed_input_ids, packed_attention_masks, packed_seq_lens, prompt_id_lens = data + packed_input_ids, packed_attention_masks = packed_input_ids.to( + torch.cuda.current_device() + ), packed_attention_masks.to(torch.cuda.current_device()) + chosen_logps, rejected_logps, aux_loss, _ = self.packed_samples_forward( + self.model, packed_input_ids, packed_attention_masks, packed_seq_lens, prompt_id_lens + ) + with torch.no_grad(): + reference_chosen_logps, reference_rejected_logps, _, _ = self.packed_samples_forward( + self.ref_model, packed_input_ids, packed_attention_masks, packed_seq_lens, prompt_id_lens + ) + + loss, chosen_reward, reject_reward = self.loss_fn( + chosen_logps, rejected_logps, reference_chosen_logps, reference_rejected_logps + ) + acc_sum += (chosen_reward > reject_reward).float().mean().item() + loss_sum += loss.item() + times += 1 + step_bar.update() + + logs = { + "eval_loss": loss_sum / times, + "acc_mean": acc_sum / times, + } + logs = self.strategy.all_reduce(logs) + step_bar.set_postfix(logs) + + if self.strategy.is_rank_0(): + if self._wandb is not None: + logs = {"eval/%s" % k: v for k, v in {**logs, "global_step": steps}.items()} + self._wandb.log(logs) + elif self._tensorboard is not None: + for k, v in logs.items(): + self._tensorboard.add_scalar(f"eval/{k}", v, steps) + self.model.train() # reset model state + + def concatenated_forward(self, model, chosen_ids, c_mask, reject_ids, r_mask, prompt_id_lens): + """Run the given model on the given batch of inputs, concatenating the chosen and rejected inputs together. + + We do this to avoid doing two forward passes, because it's faster for FSDP. 
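+
+        (Added note: returns the summed response log-probs for the chosen and rejected
+        halves of the batch, the MoE auxiliary loss when the backbone exposes one, and
+        the mean NLL loss computed over the chosen responses.)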
+ """ + input_ids, att_masks, prompt_id_lens = self.concatenated_inputs( + chosen_ids, c_mask, reject_ids, r_mask, prompt_id_lens + ) + output = model(input_ids, attention_mask=att_masks, return_output=True) + all_logits = output["logits"] + all_logps_sum, all_logps_mean = self._get_batch_logps( + all_logits, input_ids, att_masks, prompt_id_lens, average_log_prob=False + ) + chosen_logps = all_logps_sum[: chosen_ids.shape[0]] + rejected_logps = all_logps_sum[chosen_ids.shape[0] :] + aux_loss = output.aux_loss if "aux_loss" in output else [] + return chosen_logps, rejected_logps, aux_loss, -all_logps_mean[: chosen_ids.shape[0]].mean() + + def concatenated_inputs(self, chosen_ids, c_mask, reject_ids, r_mask, prompt_id_lens): + """Concatenate the chosen and rejected inputs into a single tensor. + + Args: + batch: A batch of data. Must contain the keys 'chosen_input_ids' and 'rejected_input_ids', which are tensors of shape (batch_size, sequence_length). + + Returns: + A dictionary containing the concatenated inputs under the key 'concatenated_input_ids'. + """ + + def pad_to_length(tensor, length, pad_value, dim=-1): + if tensor.size(dim) >= length: + return tensor + else: + pad_size = list(tensor.shape) + pad_size[dim] = length - tensor.size(dim) + return torch.cat( + [tensor, pad_value * torch.ones(*pad_size, dtype=tensor.dtype, device=tensor.device)], dim=dim + ) + + max_length = max(chosen_ids.shape[1], reject_ids.shape[1]) + inputs_ids = torch.cat( + ( + pad_to_length(chosen_ids, max_length, self.tokenizer.pad_token_id), + pad_to_length(reject_ids, max_length, self.tokenizer.pad_token_id), + ), + dim=0, + ) + max_length = max(c_mask.shape[1], r_mask.shape[1]) + att_masks = torch.cat((pad_to_length(c_mask, max_length, 0), pad_to_length(r_mask, max_length, 0)), dim=0) + return inputs_ids, att_masks, prompt_id_lens * 2 + + def _get_batch_logps( + self, + logits: torch.FloatTensor, + labels: torch.LongTensor, + attention_mask, + prompt_id_lens, + average_log_prob: bool = False, + ) -> torch.FloatTensor: + """Compute the log probabilities of the given labels under the given logits. + + Args: + logits: Logits of the model (unnormalized). Shape: (batch_size, sequence_length, vocab_size) + labels: Labels for which to compute the log probabilities. Label tokens with a value of -100 are ignored. Shape: (batch_size, sequence_length) + average_log_prob: If True, return the average log probability per (non-masked) token. Otherwise, return the sum of the log probabilities of the (non-masked) tokens. + + Returns: + A tensor of shape (batch_size,) containing the average/sum log probabilities of the given labels under the given logits. 
+ """ + assert average_log_prob == False + assert logits.shape[:-1] == labels.shape + + labels = labels[:, 1:].clone() + logits = logits[:, :-1, :] + + loss_masks = attention_mask.clone().bool() + # mask prompts + for mask, source_len in zip(loss_masks, prompt_id_lens): + mask[:source_len] = False + loss_masks = loss_masks[:, 1:] + + # dummy token; we'll ignore the losses on these tokens later + labels[loss_masks == False] = 0 + per_token_logps = torch.gather(logits.log_softmax(-1), dim=2, index=labels.unsqueeze(2)).squeeze(2) + + logprobs_sums = (per_token_logps * loss_masks).sum(-1) + logprobs_means = (per_token_logps * loss_masks).sum(-1) / loss_masks.sum(-1) + return logprobs_sums, logprobs_means + + def packed_samples_forward(self, model, packed_input_ids, packed_attention_masks, packed_seq_lens, prompt_id_lens): + output = model( + packed_input_ids, + attention_mask=packed_attention_masks, + return_output=True, + ring_attn_group=self.strategy.ring_attn_group, + packed_seq_lens=packed_seq_lens, + ) + all_logits = output["logits"] + all_logps_sum, all_logps_mean = self._packed_get_batch_logps( + all_logits, + packed_input_ids, + packed_attention_masks, + prompt_id_lens * 2, + packed_seq_lens, + average_log_prob=False, + ) + chosen_logps = all_logps_sum[: len(packed_seq_lens) // 2] + rejected_logps = all_logps_sum[len(packed_seq_lens) // 2 :] + aux_loss = output.aux_loss if "aux_loss" in output else [] + return chosen_logps, rejected_logps, aux_loss, -all_logps_mean[: len(packed_seq_lens) // 2].mean() + + def _packed_get_batch_logps( + self, + logits: torch.FloatTensor, + labels: torch.LongTensor, + attention_mask, + prompt_id_lens, + packed_seq_lens, + average_log_prob: bool = False, + ) -> torch.FloatTensor: + assert average_log_prob == False + + if self.strategy.ring_attn_group is None: + assert logits.shape[:-1] == labels.shape + labels = labels[:, 1:] + logits = logits[:, :-1, :] + per_token_logps = torch.gather(logits.log_softmax(-1), dim=2, index=labels.unsqueeze(2)).squeeze(2) + else: + rank = self.strategy.ring_attn_rank + total_seq_len = labels.numel() + local_seq_len = total_seq_len // self.strategy.ring_attn_size + local_slice = slice(rank * local_seq_len + 1, (rank + 1) * local_seq_len + 1) + local_label = labels[:, local_slice] + if rank == self.strategy.ring_attn_size - 1: + # add a dummy label to the last logit + local_label = F.pad(local_label, (0, 1), value=0) + local_per_token_logps = torch.gather( + logits.log_softmax(-1), dim=2, index=local_label.unsqueeze(2) + ).squeeze(2) + # we may not need to all_gather the entire tensor, but it's easier to implement. + # use the flash_attn all_gather so that the all_gather has correct backward. 
+ per_token_logps = all_gather(local_per_token_logps, self.strategy.ring_attn_group).reshape((1, -1)) + per_token_logps = per_token_logps[:, :-1] + + loss_masks = attention_mask.clone().bool() + + index = 0 + for i, seq_len in enumerate(packed_seq_lens): + loss_masks[0, index : index + prompt_id_lens[i]] = False + index = index + seq_len + + loss_masks = loss_masks[:, 1:] + + logprobs_sums = [] + logprobs_means = [] + index = 0 + for i, seq_len in enumerate(packed_seq_lens): + seq = per_token_logps[0, index : index + seq_len - 1] + mask = loss_masks[0, index : index + seq_len - 1] + logprobs_sums.append((seq * mask).sum()) + logprobs_means.append((seq * mask).sum() / mask.sum()) + index = index + seq_len + + return torch.stack(logprobs_sums), torch.stack(logprobs_means) diff --git a/PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/openrlhf/trainer/kd_trainer.py b/PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/openrlhf/trainer/kd_trainer.py new file mode 100644 index 0000000000..a6d2244836 --- /dev/null +++ b/PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/openrlhf/trainer/kd_trainer.py @@ -0,0 +1,249 @@ +import os +from abc import ABC + +import torch +from torch.optim import Optimizer +from tqdm import tqdm + +from openrlhf.models import GPTLMLoss, KDLoss +from openrlhf.utils.distributed_sampler import DistributedSampler + + +class KDTrainer(ABC): + """ + Trainer for Knowledge Distillation. + + Args: + model (torch.nn.Module): The model to be trained. + strategy (Strategy): The training strategy to be applied. + optim (Optimizer): The optimizer for model training. + train_dataloader (DataLoader): The dataloader for the training dataset. + eval_dataloader (DataLoader): The dataloader for the evaluation dataset. + scheduler (Scheduler): The learning rate scheduler to adjust training rates. + max_norm (float, defaults to 1): Maximum gradient norm for clipping to prevent exploding gradients. + pretrain_mode (bool, defaults to False): Flag to indicate if the trainer is in pre-training mode. + batch_size (int, defaults to 1): Batch size for training. + max_epochs (int, defaults to 2): The maximum number of training epochs. + tokenizer (Tokenizer, optional): The tokenizer for processing input data. 
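+
+    (Added note: the teacher_model passed to __init__ is kept in eval mode and its
+    logits, computed under torch.no_grad(), provide the distillation targets; the
+    final loss mixes the LM loss and the KD loss via args.kd_coef.)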
+ """ + + def __init__( + self, + model, + teacher_model, + strategy, + optim: Optimizer, + train_dataloader, + eval_dataloader, + scheduler, + max_norm: float = 1, + pretrain_mode: bool = False, + batch_size: int = 1, + max_epochs: int = 2, + tokenizer=None, + ) -> None: + super().__init__() + self.strategy = strategy + self.epochs = max_epochs + self.batch_size = batch_size + self.max_norm = max_norm + self.train_dataloader = train_dataloader + self.eval_dataloader = eval_dataloader + self.scheduler = scheduler + self.pretrain_mode = pretrain_mode + self.model = model + self.teacher_model = teacher_model + self.tokenizer = tokenizer + self.optimizer = optim + self.args = strategy.args + + self.loss_fn = GPTLMLoss() + self.kd_loss = KDLoss() + + # wandb/tensorboard setting + self._wandb = None + self._tensorboard = None + if self.strategy.args.use_wandb and self.strategy.is_rank_0(): + import wandb + + self._wandb = wandb + if not wandb.api.api_key: + wandb.login(key=strategy.args.use_wandb) + wandb.init( + entity=strategy.args.wandb_org, + project=strategy.args.wandb_project, + group=strategy.args.wandb_group, + name=strategy.args.wandb_run_name, + config=strategy.args.__dict__, + reinit=True, + ) + + wandb.define_metric("train/global_step") + wandb.define_metric("train/*", step_metric="train/global_step", step_sync=True) + wandb.define_metric("eval/global_step") + wandb.define_metric("eval/*", step_metric="eval/global_step", step_sync=True) + + # Initialize TensorBoard writer if wandb is not available + if self.strategy.args.use_tensorboard and self._wandb is None and self.strategy.is_rank_0(): + from torch.utils.tensorboard import SummaryWriter + + os.makedirs(self.strategy.args.use_tensorboard, exist_ok=True) + log_dir = os.path.join(self.strategy.args.use_tensorboard, strategy.args.wandb_run_name) + self._tensorboard = SummaryWriter(log_dir=log_dir) + + def fit(self, args, consumed_samples=0, num_update_steps_per_epoch=None): + # get eval and save steps + if args.eval_steps == -1: + args.eval_steps = num_update_steps_per_epoch # Evaluate once per epoch + if args.save_steps == -1: + args.save_steps = float("inf") # do not save ckpt + + # Restore step and start_epoch + step = consumed_samples // args.train_batch_size * self.strategy.accumulated_gradient + 1 + start_epoch = consumed_samples // args.train_batch_size // num_update_steps_per_epoch + consumed_samples = consumed_samples % (num_update_steps_per_epoch * args.train_batch_size) + + epoch_bar = tqdm( + range(start_epoch, self.epochs), + desc="Train epoch", + disable=not self.strategy.is_rank_0(), + ) + loss_sum = 0 + for epoch in range(start_epoch, self.epochs): + if isinstance(self.train_dataloader.sampler, DistributedSampler): + self.train_dataloader.sampler.set_epoch( + epoch, consumed_samples=0 if epoch > start_epoch else consumed_samples + ) + + step_bar = tqdm( + range(self.train_dataloader.__len__()), + desc="Train step of epoch %d" % epoch, + disable=not self.strategy.is_rank_0(), + ) + + # train + self.model.train() + self.teacher_model.eval() + for prompts_id_len, inputs, attention_masks, _ in self.train_dataloader: + inputs = inputs.squeeze(1).to(torch.cuda.current_device()) + attention_mask = attention_masks.squeeze(1).to(torch.cuda.current_device()) + output = self.model(inputs, attention_mask=attention_mask, return_output=True) + + # loss function + labels = torch.where( + attention_mask.bool(), + inputs, + self.loss_fn.IGNORE_INDEX, + ) + + if not self.pretrain_mode: + for label, source_len in zip(labels, 
prompts_id_len): + label[:source_len] = self.loss_fn.IGNORE_INDEX + + gpt_loss = self.loss_fn(output.logits, labels) + + with torch.no_grad(): + teacher_logits = self.teacher_model(inputs, attention_mask=attention_mask, return_output=True)[ + "logits" + ] + distil_loss = self.kd_loss(output.logits, teacher_logits, labels) + + loss = gpt_loss * (1 - self.args.kd_coef) + distil_loss * self.args.kd_coef + self.strategy.backward(loss, self.model, self.optimizer) + self.strategy.optimizer_step(self.optimizer, self.model, self.scheduler) + + loss_sum += gpt_loss.item() + logs_dict = { + "gpt_loss": gpt_loss.item(), + "distil_loss": distil_loss.item(), + "lr": self.scheduler.get_last_lr()[0], + } + # step bar + logs_dict = self.strategy.all_reduce(logs_dict) + step_bar.set_postfix(logs_dict) + step_bar.update() + + # logs/checkpoints/evaluation + if step % self.strategy.accumulated_gradient == 0: + logs_dict["loss_mean"] = loss_sum / self.strategy.accumulated_gradient + loss_sum = 0 + global_step = step // self.strategy.accumulated_gradient + client_states = {"consumed_samples": global_step * args.train_batch_size} + self.save_logs_and_checkpoints(args, global_step, step_bar, logs_dict, client_states) + + step += 1 + + epoch_bar.update() + + if self._wandb is not None and self.strategy.is_rank_0(): + self._wandb.finish() + if self._tensorboard is not None and self.strategy.is_rank_0(): + self._tensorboard.close() + + # logs/checkpoints/evaluation + def save_logs_and_checkpoints(self, args, global_step, step_bar, logs_dict={}, client_states={}): + if global_step % args.logging_steps == 0: + # wandb + if self._wandb is not None and self.strategy.is_rank_0(): + logs = {"train/%s" % k: v for k, v in {**logs_dict, "global_step": global_step}.items()} + self._wandb.log(logs) + # TensorBoard + elif self._tensorboard is not None and self.strategy.is_rank_0(): + for k, v in logs_dict.items(): + self._tensorboard.add_scalar(f"train/{k}", v, global_step) + + # eval + if global_step % args.eval_steps == 0: + # do eval when len(dataloader) > 0, avoid zero division in eval. 
+ if len(self.eval_dataloader) > 0: + self.evaluate(self.eval_dataloader, global_step) + # save ckpt + # TODO: save best model on dev, use loss/perplexity on whole dev dataset as metric + if global_step % args.save_steps == 0: + tag = f"global_step{global_step}" + self.strategy.save_ckpt( + self.model.model, args.ckpt_path, tag, args.max_ckpt_num, args.max_ckpt_mem, client_states + ) + + def evaluate(self, eval_dataloader, steps=0): + times = 0 + self.model.eval() + with torch.no_grad(): + loss_sum = 0 + step_bar = tqdm( + range(eval_dataloader.__len__()), + desc="Eval stage of steps %d" % steps, + disable=not self.strategy.is_rank_0(), + ) + + for prompts_id_len, inputs, attention_masks, _ in eval_dataloader: + inputs = inputs.squeeze(1).to(torch.cuda.current_device()) + attention_mask = attention_masks.squeeze(1).to(torch.cuda.current_device()) + logits = self.model(inputs, attention_mask=attention_mask, return_output=True)["logits"] + + labels = torch.where( + attention_mask.bool(), + inputs, + self.loss_fn.IGNORE_INDEX, + ) + if not self.pretrain_mode: + for label, source_len in zip(labels, prompts_id_len): + label[:source_len] = self.loss_fn.IGNORE_INDEX + loss = self.loss_fn(logits, labels) + + times += 1 + loss_sum += loss.item() + bar_dict = {"eval gpt_loss": loss_sum / times} + step_bar.update() + logs = self.strategy.all_reduce(bar_dict) + step_bar.set_postfix(logs) + + if self.strategy.is_rank_0(): + if self._wandb is not None: + logs = {"eval/%s" % k: v for k, v in {**logs, "global_step": steps}.items()} + self._wandb.log(logs) + elif self._tensorboard is not None: + for k, v in logs.items(): + self._tensorboard.add_scalar(f"eval/{k}", v, steps) + + self.model.train() # reset model state diff --git a/PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/openrlhf/trainer/kto_trainer.py b/PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/openrlhf/trainer/kto_trainer.py new file mode 100644 index 0000000000..60af93d209 --- /dev/null +++ b/PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/openrlhf/trainer/kto_trainer.py @@ -0,0 +1,342 @@ +import os +from abc import ABC + +import torch +from torch.optim import Optimizer +from tqdm import tqdm + +from openrlhf.models import KTOLoss +from openrlhf.utils.distributed_sampler import DistributedSampler + + +class KTOTrainer(ABC): + """ + Trainer for KTO training. + + Args: + model (torch.nn.Module): The primary model to be trained. + ref_model (torch.nn.Module): The reference model for comparing and guiding preference. + strategy (Strategy): The strategy to use for training. + tokenizer (Tokenizer): The tokenizer for processing input data. + optim (Optimizer): The optimizer for training the model. + train_dataloader (DataLoader): The dataloader for the training dataset. + eval_dataloader (DataLoader): The dataloader for the evaluation dataset. + scheduler (Scheduler): The learning rate scheduler to control learning rate during training. + max_norm (float, defaults to 0.5): Maximum gradient norm for gradient clipping. + beta (float, defaults to 0.01): Coefficient for regularizing the preference loss. + max_epochs (int, defaults to 2): Maximum number of training epochs. 
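+
+    (Added note: each local batch is expected to hold the matched, supervised examples
+    in its front half and the unmatched examples used for the KL estimate in its back
+    half, which is why the effective batch size must be at least 2; see
+    compute_model_logps_with_KL below.)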
+ """ + + def __init__( + self, + model, + ref_model, + strategy, + tokenizer, + optim: Optimizer, + train_dataloader, + eval_dataloader, + scheduler, + max_norm=0.5, + beta=0.01, + max_epochs: int = 2, + ) -> None: + super().__init__() + self.strategy = strategy + self.epochs = max_epochs + self.max_norm = max_norm + self.model = model + self.train_dataloader = train_dataloader + self.eval_dataloader = eval_dataloader + self.ref_model = ref_model + self.scheduler = scheduler + self.optimizer = optim + self.tokenizer = tokenizer + self.args = strategy.args + + self.beta = beta + self.loss_fn = KTOLoss( + self.beta, + self.args.desirable_loss_weight, + self.args.undesirable_loss_weight, + self.strategy.world_size, + torch.cuda.current_device(), + ) + + # Mixtral 8*7b + self.aux_loss = self.args.aux_loss_coef > 1e-8 + + # wandb/tensorboard setting + self._wandb = None + self._tensorboard = None + if self.strategy.args.use_wandb and self.strategy.is_rank_0(): + import wandb + + self._wandb = wandb + if not wandb.api.api_key: + wandb.login(key=strategy.args.use_wandb) + wandb.init( + entity=strategy.args.wandb_org, + project=strategy.args.wandb_project, + group=strategy.args.wandb_group, + name=strategy.args.wandb_run_name, + config=strategy.args.__dict__, + reinit=True, + ) + + wandb.define_metric("train/global_step") + wandb.define_metric("train/*", step_metric="train/global_step", step_sync=True) + wandb.define_metric("eval/global_step") + wandb.define_metric("eval/*", step_metric="eval/global_step", step_sync=True) + + # Initialize TensorBoard writer if wandb is not available + if self.strategy.args.use_tensorboard and self._wandb is None and self.strategy.is_rank_0(): + from torch.utils.tensorboard import SummaryWriter + + os.makedirs(self.strategy.args.use_tensorboard, exist_ok=True) + log_dir = os.path.join(self.strategy.args.use_tensorboard, strategy.args.wandb_run_name) + self._tensorboard = SummaryWriter(log_dir=log_dir) + + def fit(self, args, consumed_samples=0, num_update_steps_per_epoch=None): + # get eval and save steps + if args.eval_steps == -1: + args.eval_steps = num_update_steps_per_epoch # Evaluate once per epoch + if args.save_steps == -1: + args.save_steps = float("inf") # do not save ckpt + + # Restore step and start_epoch + step = consumed_samples // args.train_batch_size * self.strategy.accumulated_gradient + 1 + start_epoch = consumed_samples // args.train_batch_size // num_update_steps_per_epoch + consumed_samples = consumed_samples % (num_update_steps_per_epoch * args.train_batch_size) + + epoch_bar = tqdm(range(start_epoch, self.epochs), desc="Train epoch", disable=not self.strategy.is_rank_0()) + loss_sum = 0 + for epoch in range(start_epoch, self.epochs): + if isinstance(self.train_dataloader.sampler, DistributedSampler): + self.train_dataloader.sampler.set_epoch( + epoch, consumed_samples=0 if epoch > start_epoch else consumed_samples + ) + + step_bar = tqdm( + range(self.train_dataloader.__len__()), + desc="Train step of epoch %d" % epoch, + disable=not self.strategy.is_rank_0(), + ) + + self.model.train() + self.ref_model.eval() + + # train + for input_ids, attention_mask, labels, prompt_ids_lens in self.train_dataloader: + input_ids = input_ids.squeeze(1).to(torch.cuda.current_device()) + attention_mask = attention_mask.squeeze(1).to(torch.cuda.current_device()) + + # make sure local batch size >= 2 (to pack unmatched examples) + policy_returns = self.compute_model_logps_with_KL( + self.model, input_ids, attention_mask, labels, prompt_ids_lens + ) + aux_loss 
= policy_returns[3] + + with torch.no_grad(): + ref_returns = self.compute_model_logps_with_KL( + self.ref_model, input_ids, attention_mask, labels, prompt_ids_lens + ) + + kto_loss, chosen_rewards, rejected_rewards, KL = self.loss_fn( + policy_returns[0], + policy_returns[1], + policy_returns[2], + ref_returns[0], + ref_returns[1], + ref_returns[2], + ) + + # mixtral + if not self.aux_loss: + aux_loss = 0 + + loss = kto_loss + aux_loss * self.args.aux_loss_coef + self.strategy.backward(loss, self.model, self.optimizer) + self.strategy.optimizer_step(self.optimizer, self.model, self.scheduler) + + loss_sum += loss.item() + logs_dict = { + "kto_loss": loss.item(), + "chosen_reward": chosen_rewards.mean().item() if len(chosen_rewards) != 0 else 0, + "reject_reward": rejected_rewards.mean().item() if len(rejected_rewards) != 0 else 0, + "lr": self.scheduler.get_last_lr()[0], + } + logs_dict["kl"] = KL.item() + logs_dict = self.strategy.all_reduce(logs_dict) + step_bar.set_postfix(logs_dict) + step_bar.update() + + # logs/checkpoints/evaluation + if step % self.strategy.accumulated_gradient == 0: + logs_dict["loss_mean"] = loss_sum / self.strategy.accumulated_gradient + loss_sum = 0 + global_step = step // self.strategy.accumulated_gradient + client_states = {"consumed_samples": global_step * args.train_batch_size} + self.save_logs_and_checkpoints(args, global_step, step_bar, logs_dict, client_states) + + step += 1 + epoch_bar.update() + + if self._wandb is not None and self.strategy.is_rank_0(): + self._wandb.finish() + if self._tensorboard is not None and self.strategy.is_rank_0(): + self._tensorboard.close() + + # logs/checkpoints/evaluate + def save_logs_and_checkpoints(self, args, global_step, step_bar, logs_dict={}, client_states={}): + # logs + if global_step % args.logging_steps == 0: + # wandb + if self._wandb is not None and self.strategy.is_rank_0(): + logs = {"train/%s" % k: v for k, v in {**logs_dict, "global_step": global_step}.items()} + self._wandb.log(logs) + # TensorBoard + elif self._tensorboard is not None and self.strategy.is_rank_0(): + for k, v in logs_dict.items(): + self._tensorboard.add_scalar(f"train/{k}", v, global_step) + + # eval + if global_step % args.eval_steps == 0: + # do eval when len(dataloader) > 0, avoid zero division in eval. 
+ if len(self.eval_dataloader) > 0: + self.evaluate(global_step) + # save ckpt + # TODO: save best model on dev, use loss/perplexity on whole dev dataset as metric + if global_step % args.save_steps == 0: + tag = f"global_step{global_step}" + self.strategy.save_ckpt( + self.model.model, args.ckpt_path, tag, args.max_ckpt_num, args.max_ckpt_mem, client_states + ) + + def evaluate(self, steps=0): + self.model.eval() + with torch.no_grad(): + step_bar = tqdm( + range(self.eval_dataloader.__len__()), + desc="Eval stage of global_step %d" % steps, + disable=not self.strategy.is_rank_0(), + ) + + loss_sum = 0 + chosen_reward, reject_reward = 0, 0 + for input_ids, attention_mask, labels, prompt_ids_lens in self.eval_dataloader: + input_ids = input_ids.squeeze(1).to(torch.cuda.current_device()) + attention_mask = attention_mask.squeeze(1).to(torch.cuda.current_device()) + + # make sure local batch size >= 2 (to pack unmatched examples) + policy_returns = self.compute_model_logps_with_KL( + self.model, input_ids, attention_mask, labels, prompt_ids_lens + ) + aux_loss = policy_returns[3] + + with torch.no_grad(): + ref_returns = self.compute_model_logps_with_KL( + self.ref_model, input_ids, attention_mask, labels, prompt_ids_lens + ) + + kto_loss, chosen_rewards, rejected_rewards, KL = self.loss_fn( + policy_returns[0], + policy_returns[1], + policy_returns[2], + ref_returns[0], + ref_returns[1], + ref_returns[2], + ) + + chosen_reward += chosen_rewards.mean().item() + reject_reward += rejected_rewards.mean().item() + loss_sum += kto_loss.item() + step_bar.update() + + loss_mean = loss_sum / self.eval_dataloader.__len__() + chosen_reward = chosen_reward / self.eval_dataloader.__len__() + reject_reward = reject_reward / self.eval_dataloader.__len__() + + logs = {"eval_loss": loss_mean, "chosen_reward": chosen_reward, "reject_reward": reject_reward} + logs = self.strategy.all_reduce(logs) + step_bar.set_postfix(logs) + + if self.strategy.is_rank_0(): + if self._wandb is not None: + logs = {"eval/%s" % k: v for k, v in {**logs, "global_step": steps}.items()} + self._wandb.log(logs) + elif self._tensorboard is not None: + for k, v in logs.items(): + self._tensorboard.add_scalar(f"eval/{k}", v, steps) + self.model.train() + + def compute_model_logps_with_KL(self, model, input_ids, attention_mask, labels, prompt_id_lens): + """ + the front half is matched for spv, the latter half is unmatched for KL + """ + hsize = input_ids.shape[0] // 2 + + # front half + chosen_logps, reject_logps, aux_loss = self.compute_model_logps( + model, input_ids[:hsize], attention_mask[:hsize], labels[:hsize], prompt_id_lens[:hsize] + ) + + # latter half + output = model(input_ids[hsize:], attention_mask=attention_mask[hsize:], return_output=True) + all_logits = output["logits"] + KL_logps = self._get_batch_logps( + all_logits, + input_ids[hsize:], + attention_mask=attention_mask[hsize:], + average_log_prob=False, + prompt_id_lens=prompt_id_lens[hsize:], + ) + return chosen_logps, reject_logps, KL_logps, aux_loss + + def compute_model_logps(self, model, input_ids, attention_mask, labels, prompt_id_lens): + output = model(input_ids, attention_mask=attention_mask, return_output=True) + all_logits = output["logits"] + all_logps = self._get_batch_logps( + all_logits, input_ids, attention_mask=attention_mask, average_log_prob=False, prompt_id_lens=prompt_id_lens + ) + chosen_logps = all_logps[labels == 1] + reject_logps = all_logps[labels == 0] + aux_loss = output.aux_loss if "aux_loss" in output else [] + return chosen_logps, 
reject_logps, aux_loss + + def _get_batch_logps( + self, + logits: torch.FloatTensor, + labels: torch.LongTensor, + attention_mask: torch.LongTensor, + average_log_prob: bool = False, + prompt_id_lens=[], + ) -> torch.FloatTensor: + """Compute the log probabilities of the given labels under the given logits. + + Args: + logits: Logits of the model (unnormalized). Shape: (batch_size, sequence_length, vocab_size) + labels: Labels for which to compute the log probabilities. Label tokens with a value of -100 are ignored. Shape: (batch_size, sequence_length) + average_log_prob: If True, return the average log probability per (non-masked) token. Otherwise, return the sum of the log probabilities of the (non-masked) tokens. + + Returns: + A tensor of shape (batch_size,) containing the average/sum log probabilities of the given labels under the given logits. + """ + assert logits.shape[:-1] == labels.shape + + labels = labels[:, 1:].clone() + logits = logits[:, :-1, :] + + loss_masks = attention_mask.clone().bool() + # mask prompts + for mask, source_len in zip(loss_masks, prompt_id_lens): + mask[:source_len] = False + loss_masks = loss_masks[:, 1:] + + # dummy token; we'll ignore the losses on these tokens later + labels[~loss_masks] = 0 + per_token_logps = torch.gather(logits.log_softmax(-1), dim=2, index=labels.unsqueeze(2)).squeeze(2) + + if average_log_prob: + return (per_token_logps * loss_masks).sum(-1) / loss_masks.sum(-1) + return (per_token_logps * loss_masks).sum(-1) diff --git a/PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/openrlhf/trainer/ppo_trainer.py b/PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/openrlhf/trainer/ppo_trainer.py new file mode 100644 index 0000000000..3018ae04a8 --- /dev/null +++ b/PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/openrlhf/trainer/ppo_trainer.py @@ -0,0 +1,523 @@ +import os +import os.path +from abc import ABC +from typing import Any, Callable, Dict, List, Optional + +import torch +import torch.nn as nn +from torch.optim import Optimizer +from torch.utils.data import DataLoader +from tqdm import tqdm + +from openrlhf.models import Actor, GPTLMLoss, PolicyLoss, ValueLoss +from openrlhf.models.utils import masked_mean +from openrlhf.utils.distributed_sampler import DistributedSampler + +from .ppo_utils import AdaptiveKLController, Experience, FixedKLController, NaiveExperienceMaker, NaiveReplayBuffer + + +class PPOTrainer(ABC): + """ + Trainer for Proximal Policy Optimization (PPO) algorithm. + + Args: + strategy (Strategy): The training strategy to use. + actor (Actor): The actor model in the PPO algorithm. + critic (nn.Module): The critic model in the PPO algorithm. + reward_model (nn.Module): The reward model for calculating rewards in the RLHF setup. + initial_model (Actor): The initial model for reference logits to limit actor updates in RLHF. + ema_model (Actor): The exponential moving average model for stable training. + actor_optim (Optimizer): The optimizer for the actor model. + critic_optim (Optimizer): The optimizer for the critic model. + actor_scheduler (Scheduler): The learning rate scheduler for the actor. + critic_scheduler (Scheduler): The learning rate scheduler for the critic. + ema_beta (float, defaults to 0.992): EMA decay rate for model stability. + init_kl_coef (float, defaults to 0.001): Initial coefficient for KL divergence. + kl_target (float, optional): Target value for KL divergence. + kl_horizon (int, defaults to 10000): Horizon for KL annealing. 
+ ptx_coef (float, defaults to 0): Coefficient for supervised loss from pre-trained data. + micro_train_batch_size (int, defaults to 8): Micro-batch size for actor training. + buffer_limit (int, defaults to 0): Maximum size of the replay buffer. + buffer_cpu_offload (bool, defaults to True): If True, offloads replay buffer to CPU. + eps_clip (float, defaults to 0.2): Clipping coefficient for policy loss. + value_clip (float, defaults to 0.2): Clipping coefficient for value function loss. + micro_rollout_batch_size (int, defaults to 8): Micro-batch size for generating rollouts. + gradient_checkpointing (bool, defaults to False): If True, enables gradient checkpointing. + max_epochs (int, defaults to 1): Number of epochs to train. + max_norm (float, defaults to 1.0): Maximum gradient norm for gradient clipping. + tokenizer (Callable, optional): Tokenizer for input data. + prompt_max_len (int, defaults to 128): Maximum length for prompts. + dataloader_pin_memory (bool, defaults to True): If True, pins memory in the data loader. + remote_rm_url (str, optional): URL for remote reward model API. + reward_fn (Callable, optional): Custom reward function for computing rewards. + save_hf_ckpt (bool): Whether to save huggingface-format model weight. + disable_ds_ckpt (bool): Whether not to save deepspeed-format model weight. (Deepspeed model weight is used for training recovery) + **generate_kwargs: Additional arguments for model generation. + """ + + def __init__( + self, + strategy, + actor: Actor, + critic: nn.Module, + reward_model: nn.Module, + initial_model: Actor, + ema_model: Actor, + actor_optim: Optimizer, + critic_optim: Optimizer, + actor_scheduler, + critic_scheduler, + ema_beta: float = 0.992, + init_kl_coef: float = 0.001, + kl_target: float = None, + kl_horizon: int = 10000, + ptx_coef: float = 0, + micro_train_batch_size: int = 8, + buffer_limit: int = 0, + buffer_cpu_offload: bool = True, + eps_clip: float = 0.2, + value_clip: float = 0.2, + micro_rollout_batch_size: int = 8, + gradient_checkpointing: bool = False, + max_epochs: int = 1, + max_norm: float = 1.0, + tokenizer: Optional[Callable[[Any], dict]] = None, + prompt_max_len: int = 128, + dataloader_pin_memory: bool = True, + remote_rm_url: str = None, + reward_fn: Callable[[List[torch.Tensor]], torch.Tensor] = None, + save_hf_ckpt: bool = False, + disable_ds_ckpt: bool = False, + **generate_kwargs, + ) -> None: + assert ( + not isinstance(reward_model, List) or len(reward_model) == 1 or reward_fn is not None + ), "reward_fn must be specified if using multiple reward models" + + super().__init__() + self.strategy = strategy + self.args = strategy.args + self.save_hf_ckpt = save_hf_ckpt + self.disable_ds_ckpt = disable_ds_ckpt + self.micro_rollout_batch_size = micro_rollout_batch_size + self.max_epochs = max_epochs + self.tokenizer = tokenizer + self.generate_kwargs = generate_kwargs + self.dataloader_pin_memory = dataloader_pin_memory + self.max_norm = max_norm + self.ptx_coef = ptx_coef + self.micro_train_batch_size = micro_train_batch_size + self.kl_target = kl_target + self.prompt_max_len = prompt_max_len + self.ema_beta = ema_beta + self.gradient_checkpointing = gradient_checkpointing + self.reward_fn = reward_fn + + self.actor = actor + self.critic = critic + self.reward_model = reward_model + self.remote_rm_url = remote_rm_url + self.initial_model = initial_model + self.ema_model = ema_model + self.actor_optim = actor_optim + self.critic_optim = critic_optim + self.actor_scheduler = actor_scheduler + 
self.critic_scheduler = critic_scheduler + + self.actor_loss_fn = PolicyLoss(eps_clip) + self.critic_loss_fn = ValueLoss(value_clip) + self.ptx_loss_fn = GPTLMLoss() + + self.freezing_actor_steps = getattr(self.args, "freezing_actor_steps", -1) + + # Mixtral 8x7b + self.aux_loss = self.args.aux_loss_coef > 1e-8 + + if self.kl_target: + self.kl_ctl = AdaptiveKLController(init_kl_coef, kl_target, kl_horizon) + else: + self.kl_ctl = FixedKLController(init_kl_coef) + + self.experience_maker = NaiveExperienceMaker( + actor, + critic, + reward_model, + initial_model, + tokenizer, + prompt_max_len, + self.kl_ctl, + strategy, + remote_rm_url, + reward_fn, + ) + packing_samples = getattr(self.args, "packing_samples", False) + self.replay_buffer = NaiveReplayBuffer( + micro_train_batch_size, buffer_limit, buffer_cpu_offload, packing_samples + ) + + # wandb/tensorboard setting + self._wandb = None + self._tensorboard = None + if self.strategy.args.use_wandb and self.strategy.is_rank_0(): + import wandb + + self._wandb = wandb + if not wandb.api.api_key: + wandb.login(key=strategy.args.use_wandb) + wandb.init( + entity=strategy.args.wandb_org, + project=strategy.args.wandb_project, + group=strategy.args.wandb_group, + name=strategy.args.wandb_run_name, + config=strategy.args.__dict__, + reinit=True, + ) + + wandb.define_metric("train/global_step") + wandb.define_metric("train/*", step_metric="train/global_step", step_sync=True) + wandb.define_metric("eval/epoch") + wandb.define_metric("eval/*", step_metric="eval/epoch", step_sync=True) + + # Initialize TensorBoard writer if wandb is not available + if self.strategy.args.use_tensorboard and self._wandb is None and self.strategy.is_rank_0(): + from torch.utils.tensorboard import SummaryWriter + + os.makedirs(self.strategy.args.use_tensorboard, exist_ok=True) + log_dir = os.path.join(self.strategy.args.use_tensorboard, strategy.args.wandb_run_name) + self._tensorboard = SummaryWriter(log_dir=log_dir) + + def fit( + self, + args, + prompts_dataloader, + pretrain_dataloader, + consumed_samples=0, + num_update_steps_per_episodes=1, + ) -> None: + num_rollouts_per_episodes = ( + num_update_steps_per_episodes + * args.train_batch_size + // args.max_epochs + // args.rollout_batch_size + // args.n_samples_per_prompt + ) + + # get eval and save steps + if args.eval_steps == -1: + args.eval_steps = num_rollouts_per_episodes # Evaluate once per epoch + if args.save_steps == -1: + args.save_steps = float("inf") # do not save ckpt + + self.prompts_dataloader = prompts_dataloader + self.pretrain_dataloader = pretrain_dataloader + + # Restore step and start_epoch + steps = consumed_samples // args.rollout_batch_size + 1 + start_episode = consumed_samples // args.rollout_batch_size // num_rollouts_per_episodes + consumed_samples = consumed_samples % (num_rollouts_per_episodes * args.rollout_batch_size) + + for episode in range(start_episode, args.num_episodes): + if isinstance(self.prompts_dataloader.sampler, DistributedSampler): + self.prompts_dataloader.sampler.set_epoch( + episode, consumed_samples=0 if episode > start_episode else consumed_samples + ) + pbar = tqdm( + range(self.prompts_dataloader.__len__()), + desc=f"Episode [{episode + 1}/{args.num_episodes}]", + disable=not self.strategy.is_rank_0(), + ) + + for rand_prompts in self.prompts_dataloader: + for i, experience in enumerate( + self.experience_maker.make_experience_list(rand_prompts, **self.generate_kwargs) + ): + if i == 0: + output = self.tokenizer.batch_decode( + 
experience.sequences[0].unsqueeze(0), skip_special_tokens=True + ) + self.strategy.print(output) + self.replay_buffer.append(experience) + + torch.cuda.empty_cache() + self.replay_buffer.normalize("advantages", self.strategy) + status = self.ppo_train(steps) + self.replay_buffer.clear() + torch.cuda.empty_cache() + + if "kl" in status: + self.kl_ctl.update(status["kl"], args.rollout_batch_size * args.n_samples_per_prompt) + pbar.set_postfix(status) + + # logs/checkpoints + client_states = {"consumed_samples": steps * args.rollout_batch_size} + self.save_logs_and_checkpoints(args, steps, pbar, status, client_states) + + pbar.update() + steps = steps + 1 + + if self._wandb is not None and self.strategy.is_rank_0(): + self._wandb.finish() + if self._tensorboard is not None and self.strategy.is_rank_0(): + self._tensorboard.close() + + def ppo_train(self, global_steps=0): + # replay buffer may be empty at first, we should rebuild at each training + dataloader = DataLoader( + self.replay_buffer, + batch_size=self.replay_buffer.sample_batch_size, + shuffle=True, + drop_last=True, + pin_memory=self.dataloader_pin_memory, + collate_fn=self.replay_buffer.collate_fn, + ) + device = torch.cuda.current_device() + + status_list = [] + status_mean = {} + for epoch in range(self.max_epochs): + pbar = tqdm( + dataloader, + desc=f"Train epoch [{epoch + 1}/{self.max_epochs}]", + disable=not self.strategy.is_rank_0(), + ) + for experience in pbar: + experience.to_device(device) + status = self.training_step(experience, global_steps) + + # for DP + # weighted mean for kl + if "kl" in status: + status["kl"] *= status["response_length"] + status = self.strategy.all_reduce(status) + status["kl"] /= status["response_length"] + + short_status = {} + + if "policy_loss" in status: + short_status = { + "pg": status["policy_loss"], + "rm": status["reward"], + "ret": status["return"], + "glen": status["response_length"], + "tlen": status["total_length"], + "kl": status["kl"], + "act_lr": status["actor_lr"], + } + + if "critic_loss" in status: + short_status["cri"] = status["critic_loss"] + short_status["vals"] = status["values"] + short_status["cri_lr"] = status["critic_lr"] + + if "ptx_loss" in status: + short_status["ptx"] = status["ptx_loss"] + + status_list.append(status) + pbar.set_postfix(short_status) + + if status_list: + status_mean = status_list[0] + for m in status_list[1:]: + for k, v in m.items(): + status_mean[k] += v + for k in status_mean.keys(): + status_mean[k] /= len(status_list) + return status_mean + + def training_step(self, experience: Experience, global_steps) -> Dict[str, float]: + status = {} + if global_steps > self.freezing_actor_steps: + status = self.training_step_actor(experience) + if self.critic is not None: + status.update(self.training_step_critic(experience)) + return status + + def training_step_actor(self, experience: Experience) -> Dict[str, float]: + self.actor.train() + + # TODO: this is a bad indicator to say that data is packed... 
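+        # Packed experiences store per-sample tensors as Python lists: the branch
+        # below flattens them into a single (1, total_length) row, records the
+        # per-sample `packed_seq_lens`/`num_actions`, and builds an attention mask
+        # whose value i + 1 tags the tokens of sample i. Batched experiences keep
+        # padded (B, S) tensors and rely on `action_mask` instead.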
+ if isinstance(experience.sequences, list): + sequences = torch.cat(experience.sequences, dim=0).unsqueeze(0) + old_action_log_probs = torch.cat(experience.action_log_probs, dim=0).unsqueeze(0) + advantages = torch.cat(experience.advantages, dim=0).unsqueeze(0) + num_actions = [v.numel() for v in experience.advantages] + packed_seq_lens = [s.numel() for s in experience.sequences] + attention_mask = torch.cat( + [torch.full_like(s, i + 1) for i, s in enumerate(experience.sequences)], dim=0 + ).unsqueeze(0) + else: + sequences = experience.sequences + old_action_log_probs = experience.action_log_probs + advantages = experience.advantages + num_actions = experience.action_mask.size(1) + packed_seq_lens = None + attention_mask = experience.attention_mask + + # actor loss + action_log_probs, output = self.actor( + sequences, + num_actions, + attention_mask=attention_mask, + return_output=True, + packed_seq_lens=packed_seq_lens, + ) + + # loss function + actor_loss = self.actor_loss_fn( + action_log_probs, + old_action_log_probs, + advantages, + action_mask=experience.action_mask, + ) + # mixtral + if self.aux_loss: + aux_loss = output.aux_loss + else: + aux_loss = 0 + loss = actor_loss + aux_loss * self.args.aux_loss_coef + self.strategy.backward(loss, self.actor, self.actor_optim) + + # ptx loss + if self.pretrain_dataloader is not None: + data = next(self.pretrain_dataloader) + inputs = data[1].squeeze(1).to(torch.cuda.current_device()) + attention_mask = data[2].squeeze(1).to(torch.cuda.current_device()) + label = torch.where( + attention_mask.bool(), + inputs, + self.ptx_loss_fn.IGNORE_INDEX, + ) + + output = self.actor(inputs, attention_mask=attention_mask, return_output=True) + ptx_log_probs = output["logits"] + + # loss function + ptx_loss = self.ptx_loss_fn(ptx_log_probs, label) + # mixtral + if self.aux_loss: + aux_loss = output.aux_loss + else: + aux_loss = 0 + loss = ptx_loss + aux_loss * self.args.aux_loss_coef + self.strategy.backward(self.ptx_coef * loss, self.actor, self.actor_optim) + + self.strategy.optimizer_step(self.actor_optim, self.actor, self.actor_scheduler, name="actor") + if self.ema_model: + self.strategy.moving_average(self.actor, self.ema_model, self.ema_beta, "cpu") + + # status + status = {"policy_loss": actor_loss.item(), "actor_lr": self.actor_scheduler.get_last_lr()[0]} + if self.pretrain_dataloader is not None: + status["ptx_loss"] = ptx_loss.item() + for k, v in experience.info.items(): + if k == "kl": + status[k] = ( + (v * experience.info["response_length"]).sum() / experience.info["response_length"].sum() + ).item() + else: + status[k] = v.mean().item() + return status + + def training_step_critic(self, experience: Experience) -> Dict[str, float]: + self.critic.train() + + # TODO: this is a bad indicator to say that data is packed... 
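+        # Same packed/batched dispatch as in training_step_actor above, but the
+        # per-sample tensors gathered here are `values` and `returns` for the critic loss.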
+ if isinstance(experience.sequences, list): + sequences = torch.cat(experience.sequences, dim=0).unsqueeze(0) + old_values = torch.cat(experience.values, dim=0).unsqueeze(0) + returns = torch.cat(experience.returns, dim=0).unsqueeze(0) + num_actions = [v.numel() for v in experience.advantages] + packed_seq_lens = [s.numel() for s in experience.sequences] + attention_mask = torch.cat( + [torch.full_like(s, i + 1) for i, s in enumerate(experience.sequences)], dim=0 + ).unsqueeze(0) + else: + sequences = experience.sequences + old_values = experience.values + returns = experience.returns + num_actions = experience.action_mask.size(1) + packed_seq_lens = None + attention_mask = experience.attention_mask + + # critic loss + values, output = self.critic( + sequences, + num_actions=num_actions, + attention_mask=attention_mask, + return_output=True, + packed_seq_lens=packed_seq_lens, + ) + # loss function + critic_loss = self.critic_loss_fn( + values, + old_values, + returns, + action_mask=experience.action_mask, + ) + # mixtral + if self.aux_loss: + aux_loss = output.aux_loss + else: + aux_loss = 0 + loss = critic_loss + aux_loss * self.args.aux_loss_coef + self.strategy.backward(loss, self.critic, self.critic_optim) + self.strategy.optimizer_step(self.critic_optim, self.critic, self.critic_scheduler, name="critic") + + # status + status = { + "critic_loss": critic_loss.item(), + "values": masked_mean(values, experience.action_mask).item(), + "critic_lr": self.critic_scheduler.get_last_lr()[0], + } + return status + + def save_logs_and_checkpoints(self, args, global_step, step_bar, logs_dict={}, client_states={}): + if global_step % args.logging_steps == 0: + # wandb + if self._wandb is not None and self.strategy.is_rank_0(): + logs = { + "train/%s" % k: v + for k, v in { + **logs_dict, + "global_step": global_step, + }.items() + } + if self.experience_maker.perf_stats is not None: + logs.update({f"perf/experience_maker/{k}": v for k, v in self.experience_maker.perf_stats.items()}) + self._wandb.log(logs) + # TensorBoard + elif self._tensorboard is not None and self.strategy.is_rank_0(): + for k, v in logs_dict.items(): + self._tensorboard.add_scalar(f"train/{k}", v, global_step) + if self.experience_maker.perf_stats is not None: + for k, v in self.experience_maker.perf_stats.items(): + self._tensorboard.add_scalar(f"perf/experience_maker/{k}", v, global_step) + + # TODO: Add evaluation mechanism for PPO + if global_step % args.eval_steps == 0: + # self.evaluate(self.eval_dataloader, global_step) + pass + # save ckpt + # TODO: save best model on dev, use loss/perplexity/others on whole dev dataset as metric + if global_step % args.save_steps == 0: + tag = f"global_step{global_step}" + self._save_checkpoint(args, tag, client_states) + + def _save_checkpoint(self, args, tag, client_states): + if not self.disable_ds_ckpt: + self.strategy.save_ckpt( + self.actor.model, + os.path.join(args.ckpt_path, "_actor"), + tag, + args.max_ckpt_num, + args.max_ckpt_mem, + client_states, + ) + if self.critic is not None: + self.strategy.save_ckpt( + self.critic, os.path.join(args.ckpt_path, "_critic"), tag, args.max_ckpt_num, args.max_ckpt_mem + ) + + if self.save_hf_ckpt: + save_path = os.path.join(args.ckpt_path, f"{tag}_hf") + self.strategy.save_model(self.actor, self.tokenizer, save_path) diff --git a/PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/openrlhf/trainer/ppo_utils/__init__.py b/PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/openrlhf/trainer/ppo_utils/__init__.py new file mode 100644 index 
0000000000..a0c6d473f8 --- /dev/null +++ b/PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/openrlhf/trainer/ppo_utils/__init__.py @@ -0,0 +1,12 @@ +from .experience_maker import Experience, NaiveExperienceMaker, RemoteExperienceMaker +from .kl_controller import AdaptiveKLController, FixedKLController +from .replay_buffer import NaiveReplayBuffer + +__all__ = [ + "Experience", + "NaiveExperienceMaker", + "RemoteExperienceMaker", + "AdaptiveKLController", + "FixedKLController", + "NaiveReplayBuffer", +] diff --git a/PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/openrlhf/trainer/ppo_utils/experience_maker.py b/PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/openrlhf/trainer/ppo_utils/experience_maker.py new file mode 100644 index 0000000000..b42adab90e --- /dev/null +++ b/PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/openrlhf/trainer/ppo_utils/experience_maker.py @@ -0,0 +1,765 @@ +import time +from abc import ABC +from copy import deepcopy +from dataclasses import dataclass +from typing import List, Optional, Tuple, Union + +import ray +import torch +import torch.nn as nn +from tqdm import tqdm + +from openrlhf.models.actor import Actor +from openrlhf.models.utils import compute_approx_kl, compute_reward, masked_mean, unpacking_samples +from openrlhf.utils.logging_utils import init_logger +from openrlhf.utils.remote_rm_utils import remote_rm_fn, remote_rm_fn_ray + +logger = init_logger(__name__) + + +def to(tensor: Union[torch.Tensor, list[torch.Tensor]], device): + if isinstance(tensor, list): + return [to(t, device) for t in tensor] + return tensor.to(device) if isinstance(tensor, torch.Tensor) else tensor + + +def pin_memory(tensor: Union[torch.Tensor, list[torch.Tensor]]): + if isinstance(tensor, list): + return [pin_memory(t) for t in tensor] + return tensor.pin_memory() if isinstance(tensor, torch.Tensor) else tensor + + +@dataclass +class Experience: + """Experience is a batch of data. + These data should have the the sequence length and number of actions. + Left padding for sequences is applied. + + Shapes of each tensor: + sequences: (B, S) + action_log_probs: (B, A) + values: (B, A) + returns: (B, A) + advantages: (B, A) + attention_mask: (B, S) + action_mask: (B, A) + kl: (B, A) + + "A" is the number of actions. 
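+    When samples are packed, the tensors above are stored as lists of 1-D
+    per-sample tensors instead of padded batches (see `Samples` below).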
+ """ + + sequences: torch.Tensor + action_log_probs: torch.Tensor + values: torch.Tensor + returns: Optional[torch.Tensor] + advantages: Optional[torch.Tensor] + attention_mask: Optional[torch.LongTensor] + action_mask: Optional[torch.BoolTensor] + info: Optional[dict] + kl: Optional[torch.Tensor] = None + + @torch.no_grad() + def to_device(self, device: torch.device): + self.sequences = to(self.sequences, device) + self.action_log_probs = to(self.action_log_probs, device) + self.returns = to(self.returns, device) + self.advantages = to(self.advantages, device) + self.values = to(self.values, device) + self.attention_mask = to(self.attention_mask, device) + self.action_mask = to(self.action_mask, device) + self.kl = to(self.kl, device) + self.info = {key: to(value, device) for key, value in self.info.items()} + return self + + def pin_memory(self): + self.sequences = pin_memory(self.sequences) + self.action_log_probs = pin_memory(self.action_log_probs) + self.returns = pin_memory(self.returns) + self.advantages = pin_memory(self.advantages) + self.values = pin_memory(self.values) + self.attention_mask = pin_memory(self.attention_mask) + self.action_mask = pin_memory(self.action_mask) + self.kl = pin_memory(self.kl) + self.info = {key: pin_memory(value) for key, value in self.info.items()} + return self + + +@dataclass +class Samples: + """Samples is a batch of data. + There can be 2 formats to store the samples, batched or packed. + The batched format means padding is applied to the sequences, while the packed format + will concatenate the prompt and response without padding. + + Shapes of each tensor, when 2 shapes are shown, the first one is for batched format + and the second one is for packed format: + sequences: (B, S) or (1, total_length), the tokens of both prompt and response. + attention_mask: (B, S) or (1, total_length), the attention mask for sequences. + action_mask: (B, A) or None, the action (response) mask to show which part of the + sequence is the response. When the samples are packed, this is None. + num_actions: int or (B,), the number of actions (tokens) in the response. + When the samples are not packed, we will use action_mask, so this is an int to + show the size of action_mask. Otherwise, this is a tensor to show the number of + actions for each sample. + packed_seq_lens: None or (B,), the length of each sample in the packed samples. + response_length: (B,), the number of tokens in the response. + total_length: (B,), the total number of tokens in the sequences. + """ + + sequences: torch.Tensor + attention_mask: Optional[torch.LongTensor] + action_mask: Optional[torch.BoolTensor] + num_actions: Union[int, torch.Tensor] + packed_seq_lens: Optional[torch.Tensor] + response_length: torch.Tensor + total_length: torch.Tensor + + +class NaiveExperienceMaker(ABC): + """ + Naive experience maker. 
+ """ + + def __init__( + self, + actor: Actor, + critic: nn.Module, + reward_model: nn.Module, + initial_model: Actor, + tokenizer, + prompt_max_len: int, + kl_controller, + strategy=None, + remote_rm_url: str = None, + reward_fn=None, + ) -> None: + super().__init__() + self.actor = actor + self.critic = critic + self.reward_model = reward_model + self.remote_rm_url = remote_rm_url + self.initial_model = initial_model + self.tokenizer = tokenizer + self.prompt_max_len = prompt_max_len + self.kl_ctl = kl_controller + self.strategy = strategy + self.reward_fn = reward_fn + self.perf_stats = None + self.advantage_estimator = strategy.args.advantage_estimator + + # tokenizer + def tokenize_fn(self, texts, max_length, padding=True, device=None): + if not padding: + # when padding is False, return tokenized texts as list + return self.tokenizer( + texts, + add_special_tokens=False, + max_length=max_length, + truncation=True, + ) + batch = self.tokenizer( + texts, + return_tensors="pt", + add_special_tokens=False, + max_length=max_length, + padding=True, + truncation=True, + ) + return {k: v.to(device) for k, v in batch.items()} + + @torch.no_grad() + def make_experience_list(self, all_prompts: Union[str, List[str]], **generate_kwargs) -> List[Experience]: + """ + Make a list of experience with the micro_rollout_batch_size. + + This method will first calculate the response sequences and rewards for the given prompts. + Then, if we need certain processing for the rewards or do certain filtering, we can process the rollout as a whole. + After that, we will calculate the advantages and returns for each experience. + """ + args = self.strategy.args + # generate responses + samples_list = self.generate_samples(all_prompts, **generate_kwargs) + torch.distributed.barrier() + + experiences = [] + for samples in tqdm( + samples_list, + desc="make_experience", + disable=not self.strategy.is_rank_0(), + ): + experiences.append(self.make_experience(samples).to_device("cpu")) + + experiences, rewards = self.process_experiences(experiences) + + # calculate return and advantages + for experience, reward in zip(experiences, rewards): + experience = experience.to_device("cuda") + reward = reward.to(device="cuda") + num_actions = experience.info["num_actions"] + reward = compute_reward( + reward, + self.kl_ctl.value, + experience.kl, + action_mask=experience.action_mask, + num_actions=num_actions, + reward_clip_range=args.reward_clip_range, + ) + + if self.advantage_estimator == "gae": + experience.advantages, experience.returns = self.get_advantages_and_returns( + experience.values, + reward, + experience.action_mask, + generate_kwargs["gamma"], + generate_kwargs["lambd"], + ) + elif self.advantage_estimator in ["reinforce", "rloo"]: + experience.returns = self.get_cumulative_returns( + reward, + experience.action_mask, + generate_kwargs["gamma"], + ) + experience.advantages = deepcopy(experience.returns) + else: + raise Exception(f"Unkown advantage_estimator {self.advantage_estimator}") + + # calculate the return info. 
+ if not getattr(self, "packing_samples", False): + return_sums = reward.sum(dim=-1) + else: + return_sums = torch.tensor( + [each_reward.sum() for each_reward in reward], device=torch.cuda.current_device() + ) + experience.info["return"] = return_sums + # remove unnecessary info + experience.kl = None + del experience.info["num_actions"] + experience.to_device("cpu") + return experiences + + @torch.no_grad() + def generate_samples(self, all_prompts: List[str], **generate_kwargs) -> List[Samples]: + """ + Generate samples and return in batches. + """ + assert not getattr(self, "packing_samples", False) + args = self.strategy.args + self.actor.eval() + # sample multiple response + all_prompts = sum([[prompt] * args.n_samples_per_prompt for prompt in all_prompts], []) + samples_list = [] + for i in range(0, len(all_prompts), args.micro_rollout_batch_size): + prompts = all_prompts[i : i + args.micro_rollout_batch_size] + inputs = self.tokenize_fn(prompts, self.prompt_max_len, device="cuda") + sequences, attention_mask, action_mask = self.actor.generate(**inputs, **generate_kwargs) + samples = Samples( + sequences=sequences, + attention_mask=attention_mask, + action_mask=action_mask, + num_actions=action_mask.size(1), + packed_seq_lens=None, + response_length=action_mask.float().sum(dim=-1), + total_length=attention_mask.float().sum(dim=-1), + ) + samples_list.append(samples) + return samples_list + + @torch.no_grad() + def make_experience(self, samples: Samples) -> Experience: + """ + Turn samples into experience by calculating logprobs, values, rewards, and kl divergence. + """ + self.actor.eval() + self.initial_model.eval() + if self.reward_model is not None: + self.reward_model.eval() + if self.critic is not None: + self.critic.eval() + + # extract values from samples + sequences = samples.sequences + attention_mask = samples.attention_mask + action_mask = samples.action_mask + num_actions = samples.num_actions + + # log probs + action_log_probs = self.actor(sequences, num_actions, attention_mask) + + # init log probs + base_action_log_probs = self.initial_model(sequences, num_actions, attention_mask) + + # values + if self.critic is not None: + value = self.critic(sequences, num_actions, attention_mask) + else: + value = None + + # rewards + if self.remote_rm_url is not None: + # remote RM + queries = self.tokenizer.batch_decode(sequences.cpu(), skip_special_tokens=False) + r = remote_rm_fn(self.remote_rm_url, queries=queries).to(device=action_log_probs.device) + else: + # local RM + r = self.reward_model(sequences, attention_mask) + + kl = compute_approx_kl( + action_log_probs, + base_action_log_probs, + action_mask=action_mask, + use_kl_estimator_k3=self.strategy.args.use_kl_estimator_k3, + ) + + info = { + "kl": masked_mean(kl, action_mask, dim=-1), + "reward": r, + "response_length": samples.response_length, + "total_length": samples.total_length, + "num_actions": num_actions, + } + # reset model state + self.actor.train() + if self.critic is not None: + self.critic.train() + + return Experience( + sequences, + action_log_probs, + value, + None, + None, + attention_mask, + action_mask, + info, + kl, + ) + + @torch.no_grad() + def process_experiences(self, experiences: List[Experience]) -> Tuple[List[Experience], List[torch.Tensor]]: + """ + Process experiences, this can be used to filter out some experiences or do some processing on the rewards. 
+ + Output: + - experiences: List of Experience + - rewards: List of rewards + """ + args = self.strategy.args + # reward shaping for RLOO + if args.advantage_estimator == "rloo": + rewards = torch.cat([experience.info["reward"] for experience in experiences]) + rewards = rewards.reshape(-1, args.n_samples_per_prompt).to(device="cuda") + baseline = (rewards.sum(-1, keepdim=True) - rewards) / (args.n_samples_per_prompt - 1) + rewards = rewards - baseline + rewards = rewards.flatten().to(device="cpu").chunk(len(experiences)) + return experiences, rewards + # default rewards + return experiences, [experience.info["reward"] for experience in experiences] + + @torch.no_grad() + def get_advantages_and_returns( + self, + values: torch.Tensor, + rewards: torch.Tensor, + action_mask: torch.Tensor, + gamma: float, + lambd: float, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Function that computes advantages and returns from rewards and values. + Calculated as in the original PPO paper: https://arxiv.org/abs/1707.06347 + Note that rewards may include a KL divergence loss term. + + Advantages looks like this: + Adv1 = R1 + γ * λ * R2 + γ^2 * λ^2 * R3 + ... + - V1 + γ * (1 - λ) V2 + γ^2 * λ * (1 - λ) V3 + ... + + Returns looks like this: + Ret1 = R1 + γ * λ * R2 + γ^2 * λ^2 * R3 + ... + + γ * (1 - λ) V2 + γ^2 * λ * (1 - λ) V3 + ... + + Input: + - values: Tensor of shape (batch_size, response_size) + - rewards: Tensor of shape (batch_size, response_size) + + Output: + - advantages: Tensor of shape (batch_size, response_size) + - returns: Tensor of shape (batch_size, response_size) + """ + if isinstance(values, list): + # packing samples + # TODO: this is slow... + advantages = [] + returns = [] + for v, r in zip(values, rewards): + adv, ret = self.get_advantages_and_returns(v.unsqueeze(0), r.unsqueeze(0), action_mask, gamma, lambd) + advantages.append(adv.squeeze(0)) + returns.append(ret.squeeze(0)) + return advantages, returns + + lastgaelam = 0 + advantages_reversed = [] + response_length = rewards.size(1) + + # Mask invalid responses + if action_mask is not None: + values = action_mask * values + rewards = action_mask * rewards + + for t in reversed(range(response_length)): + nextvalues = values[:, t + 1] if t < response_length - 1 else 0.0 + delta = rewards[:, t] + gamma * nextvalues - values[:, t] + lastgaelam = delta + gamma * lambd * lastgaelam + advantages_reversed.append(lastgaelam) + advantages = torch.stack(advantages_reversed[::-1], dim=1) + returns = advantages + values + return advantages.detach(), returns + + @torch.no_grad() + def get_cumulative_returns( + self, + rewards: torch.Tensor, + action_mask: torch.Tensor, + gamma: float, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Function that computes advantages and returns from rewards using REINFORCE. + REINFORCE uses cumulative returns without the GAE (Generalized Advantage Estimation). + + Input: + - rewards: Tensor of shape (batch_size, response_size) + - action_mask: Tensor of shape (batch_size, response_size), binary mask + - gamma: discount factor + + Output: + - returns: Tensor of shape (batch_size, response_size) + """ + + if isinstance(rewards, list): + # packing samples + # TODO: this is slow... 
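+            # Packed rewards arrive as a list of 1-D tensors, so each sample is
+            # promoted to a singleton batch and handled by the dense path below.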
+ returns = [] + for r in rewards: + ret = self.get_cumulative_returns(r.unsqueeze(0), action_mask, gamma) + returns.append(ret.squeeze(0)) + return returns + + response_length = rewards.size(1) + returns = torch.zeros_like(rewards) + cumulative_return = torch.zeros(rewards.size(0), device=rewards.device) + + # Mask invalid responses if action_mask is provided + if action_mask is not None: + rewards = action_mask * rewards + + # Calculate returns by accumulating discounted rewards + for t in reversed(range(response_length)): + cumulative_return = rewards[:, t] + gamma * cumulative_return + returns[:, t] = cumulative_return + + return returns + + +class RemoteExperienceMaker(NaiveExperienceMaker): + def __init__(self, *args, vllm_engines: List = None, packing_samples=False, **kwargs): + super().__init__(*args, **kwargs) + self.vllm_engines = vllm_engines + self.packing_samples = packing_samples + + @torch.no_grad() + def make_experience_list(self, all_prompts: Union[str, List[str]], **generate_kwargs) -> List[Experience]: + if self.strategy.args.perf: + self.perf_stats = { + "generate_time": 0, + "actor_value_rm_time": 0, + "wait_time": 0, + } + experiences = super().make_experience_list(all_prompts, **generate_kwargs) + if self.critic is not None: + for experience in experiences: + # send experience to critic + experience_cpu = deepcopy(experience) + experience_cpu.to_device("cpu") + self._ref = self.critic.append.remote(experience_cpu) + return experiences + + @torch.no_grad() + def generate_samples(self, all_prompts: List[str], **generate_kwargs) -> List[Samples]: + """ + Generate samples and return in batches. + + When not using vllm, we will fallback to the default implementation, + in which actor will be used to generate samples. + """ + if self.vllm_engines is None: + return super().generate_samples(all_prompts, **generate_kwargs) + + return self._generate_vllm(all_prompts, **generate_kwargs) + + @torch.no_grad() + def make_experience(self, samples: Samples) -> Experience: + """ + Turn samples into experience by calculating logprobs, values, rewards, and kl divergence. 
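+
+        Unlike the naive maker, the reference/critic/reward forwards are issued as
+        Ray remote calls and overlapped with the local actor forward pass; results
+        are gathered with `ray.get` before computing the KL statistics.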
+ """ + self.actor.eval() + device = torch.cuda.current_device() + + # extract values from samples + sequences = samples.sequences + attention_mask = samples.attention_mask + action_mask = samples.action_mask + num_actions = samples.num_actions + packed_seq_lens = samples.packed_seq_lens + + start = time.time() + sequences_cpu, attention_mask_cpu = ( + sequences.to("cpu"), + attention_mask.to("cpu"), + ) + + # init log probs + base_action_log_probs_ref = self.initial_model.forward.remote( + sequences_cpu, num_actions, attention_mask_cpu, packed_seq_lens=packed_seq_lens + ) + + # values + if self.critic is not None: + value_ref = self.critic.forward.remote( + sequences_cpu, num_actions, attention_mask_cpu, packed_seq_lens=packed_seq_lens + ) + # avoid CUDA OOM when colocate models + if self.strategy.args.colocate_critic_reward: + ray.get([value_ref]) + ray.get([self.critic.empty_cache.remote()]) + else: + value_ref = ray.put(None) + + if self.strategy.args.colocate_actor_ref: + ray.get([base_action_log_probs_ref]) + ray.get([self.initial_model.empty_cache.remote()]) + + # rewards + r_refs = [] + # support remote RM API with ray + if not self.remote_rm_url: + for rm in self.reward_model: + r_refs.append(rm.forward.remote(sequences_cpu, attention_mask_cpu, packed_seq_lens=packed_seq_lens)) + else: + # remote RM + if not self.packing_samples: + queries = self.tokenizer.batch_decode(sequences_cpu, skip_special_tokens=False) + else: + sequences_list = [] + offset = 0 + tokens_list = sequences_cpu.tolist()[0] + for length in packed_seq_lens: + sequences_list.append(tokens_list[offset : offset + length]) + offset += length + queries = self.tokenizer.batch_decode(sequences_list, skip_special_tokens=False) + + for rm in self.remote_rm_url: + r = remote_rm_fn_ray.remote(rm, queries=queries) + r_refs.append(r) + + # log probs + action_log_probs = self.actor(sequences, num_actions, attention_mask, packed_seq_lens=packed_seq_lens) + actor_value_rm_time = time.time() - start + + # wait initial/critic/reward model done + start = time.time() + ref_values = ray.get([base_action_log_probs_ref, value_ref] + r_refs) + wait_time = time.time() - start + + base_action_log_probs, value, rewards = ref_values[0], ref_values[1], ref_values[2:] + base_action_log_probs = base_action_log_probs.to(device) + if value is not None: + value = value.to(device) + rewards = [r.to(device) for r in rewards] + r = self.reward_fn(rewards) if len(rewards) > 0 else rewards[0] + + # avoid CUDA OOM when colocate models + if self.strategy.args.colocate_critic_reward and not self.remote_rm_url: + ray.get([self.reward_model[0].empty_cache.remote()]) + + if self.strategy.args.colocate_actor_ref: + torch.cuda.empty_cache() + + kl = compute_approx_kl( + action_log_probs, + base_action_log_probs, + action_mask=action_mask, + use_kl_estimator_k3=self.strategy.args.use_kl_estimator_k3, + ) + + if not self.packing_samples: + kl_mean = masked_mean(kl, action_mask, dim=-1) + else: + # convert tensor into list of tensors so that it's easier to manipulate + # within dataset. 
+ sequences = unpacking_samples(sequences, packed_seq_lens) + attention_mask = None + action_log_probs = unpacking_samples(action_log_probs, num_actions) + if value is not None: + value = unpacking_samples(value, num_actions) + + kl = unpacking_samples(kl, num_actions) + kl_mean = torch.tensor([each_kl.mean() for each_kl in kl], device=device) + + info = { + "kl": kl_mean, + "reward": r, + "response_length": samples.response_length, + "total_length": samples.total_length, + "num_actions": num_actions, + } + + if self.strategy.args.perf: + self.perf_stats["actor_value_rm_time"] += actor_value_rm_time + self.perf_stats["wait_time"] += wait_time + + experience = Experience( + sequences, + action_log_probs, + value, + None, + None, + attention_mask, + action_mask, + info, + kl, + ) + + self.actor.train() # reset model state + return experience + + def _generate_vllm(self, all_prompts: List[str], **kwargs) -> List[Samples]: + from vllm import SamplingParams + + # round-robin load balance + rank = torch.distributed.get_rank() + world_size = torch.distributed.get_world_size() + + # Select LLM engines: assign each rank an engine, or cycle through engines if world_size < engine_count + if len(self.vllm_engines) <= world_size: + llms = [self.vllm_engines[rank % len(self.vllm_engines)]] + else: + llms = self.vllm_engines[rank::world_size] + + args = self.strategy.args + + sampling_params = SamplingParams( + temperature=kwargs.get("temperature", 1.0), + top_p=kwargs.get("top_p", 1.0), + top_k=kwargs.get("top_k", -1), + max_tokens=kwargs.get("max_new_tokens", 1024), + min_tokens=kwargs.get("min_new_tokens", 1), + skip_special_tokens=kwargs.get("skip_special_tokens", False), + include_stop_str_in_output=True, + ) + + # Expand prompt list based on the number of samples per prompt + all_prompts = sum([[prompt] * args.n_samples_per_prompt for prompt in all_prompts], []) + all_prompt_token_ids = self.tokenize_fn(all_prompts, self.prompt_max_len, padding=False)["input_ids"] + + # Distribute requests to engines and collect responses to outputs + all_output_refs = [] + batch_size = (len(all_prompt_token_ids) + len(llms) - 1) // len(llms) + for i, llm in enumerate(llms): + prompt_token_ids = all_prompt_token_ids[i * batch_size : (i + 1) * batch_size] + if prompt_token_ids: + all_output_refs.append( + llm.generate.remote(sampling_params=sampling_params, prompt_token_ids=prompt_token_ids) + ) + + # Retrieve and combine results from all outputs + all_outputs = sum(ray.get(all_output_refs), []) + + samples_list = [] + for i in range(0, len(all_outputs), args.micro_rollout_batch_size): + outputs = all_outputs[i : i + self.strategy.args.micro_rollout_batch_size] + if not self.packing_samples: + # NOTE: concat all outputs to following format: + # + # | [PAD] [PAD] token token token | token token [EOS] [PAD] | + # | token token token token token | token token [EOS] [PAD] | + # | [PAD] [PAD] [PAD] token token | token token token [EOS] | + # |<---------- prompt ----------->|<-------- answer ------->| + max_input_len, max_output_len = 0, 0 + for output in outputs: + max_input_len = max(max_input_len, len(output.prompt_token_ids)) + max_output_len = max(max_output_len, len(output.outputs[0].token_ids)) + + pad_token_id, eos_token_id = self.tokenizer.pad_token_id, self.tokenizer.eos_token_id + sequences = [] + for output in outputs: + # left padding input + input_len = len(output.prompt_token_ids) + input_ids = [pad_token_id] * (max_input_len - input_len) + list(output.prompt_token_ids) + + # right padding output + 
output_len = len(output.outputs[0].token_ids) + output_ids = list(output.outputs[0].token_ids) + [pad_token_id] * (max_output_len - output_len) + + # concat input and output + sequences.append(input_ids + output_ids) + + sequences = torch.tensor(sequences) + sequences, attention_mask, action_mask = self.actor.process_sequences( + sequences, max_input_len, eos_token_id, pad_token_id + ) + sequences = sequences.to("cuda") + attention_mask = attention_mask.to("cuda") + action_mask = action_mask.to("cuda") + samples_list.append( + Samples( + sequences=sequences, + attention_mask=attention_mask, + action_mask=action_mask, + num_actions=action_mask.size(1), + packed_seq_lens=None, + response_length=action_mask.float().sum(dim=-1), + total_length=attention_mask.float().sum(dim=-1), + ) + ) + else: + # NOTE: concat all outputs to following format: + # + # | token token token | token token [EOS] | token token token token token | token token [EOS] | token token | token token token [EOS] | + # |<--- prompt ----->|<---- answer ----->|<---------- prompt ----------->|<----- answer ---->|<- prompt -->|<-------- answer ------->| + pad_token_id, eos_token_id = self.tokenizer.pad_token_id, self.tokenizer.eos_token_id + sequences = [] + packed_seq_lens = [] + attention_mask = [] + num_actions = [] + for i, output in enumerate(outputs): + input_len = len(output.prompt_token_ids) + output_len = len(output.outputs[0].token_ids) + packed_seq_lens.append(input_len + output_len) + sequences.extend(output.prompt_token_ids + list(output.outputs[0].token_ids)) + attention_mask.extend([i + 1] * (input_len + output_len)) + + # current_action_mask = [0] * (input_len - 1) + [1] * output_len + [0] + # num_actions.append(max(1, sum(current_action_mask))) + num_actions.append(max(1, output_len)) + + sequences = torch.tensor(sequences, device="cuda").unsqueeze(0) + attention_mask = torch.tensor(attention_mask, device="cuda").unsqueeze(0) + action_mask = None + response_length = torch.tensor(num_actions, device="cuda", dtype=torch.float) + total_length = torch.tensor(packed_seq_lens, device="cuda", dtype=torch.float) + samples_list.append( + Samples( + sequences=sequences, + attention_mask=attention_mask, + action_mask=None, + num_actions=num_actions, + packed_seq_lens=packed_seq_lens, + response_length=response_length, + total_length=total_length, + ) + ) + return samples_list + + def flush(self): + "Ensure all experience has been send to critic" + if self.critic is not None: + ray.get(self._ref) + self._ref = None diff --git a/PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/openrlhf/trainer/ppo_utils/kl_controller.py b/PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/openrlhf/trainer/ppo_utils/kl_controller.py new file mode 100644 index 0000000000..a6e6eeed84 --- /dev/null +++ b/PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/openrlhf/trainer/ppo_utils/kl_controller.py @@ -0,0 +1,29 @@ +import numpy as np + + +class AdaptiveKLController: + """ + Adaptive KL controller described in the paper: + https://arxiv.org/pdf/1909.08593.pdf + """ + + def __init__(self, init_kl_coef, target, horizon): + self.value = init_kl_coef + self.target = target + self.horizon = horizon + + def update(self, current, n_steps): + target = self.target + proportional_error = np.clip(current / target - 1, -0.2, 0.2) + mult = 1 + proportional_error * n_steps / self.horizon + self.value *= mult + + +class FixedKLController: + """Fixed KL controller.""" + + def __init__(self, kl_coef): + self.value = kl_coef + + def update(self, current, n_steps): + pass 
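+
+# Illustrative usage (a sketch, not exercised by the trainers themselves):
+# PPOTrainer calls `update(kl, rollout_batch_size * n_samples_per_prompt)` once
+# per rollout step. With init_kl_coef=0.001, target=0.01 and horizon=10000,
+# observing kl=0.02 over 1024 samples clips the proportional error to 0.2 and
+# multiplies the coefficient by 1 + 0.2 * 1024 / 10000 ≈ 1.02, nudging it upward
+# until the measured KL drifts back toward the target.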
diff --git a/PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/openrlhf/trainer/ppo_utils/replay_buffer.py b/PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/openrlhf/trainer/ppo_utils/replay_buffer.py new file mode 100644 index 0000000000..07ef55d0dd --- /dev/null +++ b/PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/openrlhf/trainer/ppo_utils/replay_buffer.py @@ -0,0 +1,237 @@ +import random +from abc import ABC +from dataclasses import dataclass +from typing import List, Optional + +import torch +import torch.nn.functional as F + + +from .experience_maker import Experience + + +@dataclass +class BufferItem: + """BufferItem is an item of experience data. + + Shapes of each tensor: + sequences: (S) + action_log_probs: (A) + values: (1) + returns: (1) + advantages: (1) + attention_mask: (S) + action_mask: (A) + + "A" is the number of actions. + """ + + sequences: torch.Tensor + action_log_probs: torch.Tensor + values: torch.Tensor + returns: torch.Tensor + advantages: torch.Tensor + attention_mask: Optional[torch.LongTensor] + action_mask: Optional[torch.BoolTensor] + info: Optional[dict] + + +def split_experience_batch(experience: Experience) -> List[BufferItem]: + batch_size = len(experience.sequences) + batch_kwargs = [{} for _ in range(batch_size)] + keys = ( + "sequences", + "action_log_probs", + "values", + "returns", + "advantages", + "attention_mask", + "action_mask", + ) + for key in keys: + value = getattr(experience, key) + if value is None: + for i in range(batch_size): + batch_kwargs[i][key] = None + continue + vals = value + if isinstance(vals, torch.Tensor): + vals = torch.unbind(vals) + assert batch_size == len(vals) + for i, v in enumerate(vals): + batch_kwargs[i][key] = v + + for i in range(batch_size): + batch_kwargs[i]["info"] = {} + for k, v in experience.info.items(): + vals = torch.unbind(v) + assert batch_size == len(vals) + for i, vv in enumerate(vals): + if isinstance(vv, torch.Tensor): + assert vv.numel() == 1, f"info[{k}] must be a scalar tensor, but got {vv.shape}" + vv = vv.item() + batch_kwargs[i]["info"][k] = vv + + items = [BufferItem(**kwargs) for kwargs in batch_kwargs] + return items + + +def zero_pad_sequences(sequences: List[torch.Tensor], side: str = "left") -> torch.Tensor: + assert side in ("left", "right") + max_len = max(seq.size(0) for seq in sequences) + padded_sequences = [] + for seq in sequences: + pad_len = max_len - seq.size(0) + padding = (pad_len, 0) if side == "left" else (0, pad_len) + padded_sequences.append(F.pad(seq, padding)) + return torch.stack(padded_sequences, dim=0) + + +def make_experience_batch(items: List[BufferItem], packing_samples=False) -> Experience: + kwargs = {} + keys = ( + "sequences", + "action_log_probs", + "values", + "returns", + "advantages", + "attention_mask", + "action_mask", + ) + for key in keys: + vals = [getattr(item, key) for item in items] + if not packing_samples: + batch_data = zero_pad_sequences(vals, "left") if vals[0] is not None else None + else: + batch_data = vals if vals[0] is not None else None + kwargs[key] = batch_data + + kwargs["info"] = {} + for key in items[0].info.keys(): + vals = torch.tensor([item.info[key] for item in items]) + kwargs["info"][key] = vals + return Experience(**kwargs) + + +def remove_padding_in_sequences(items): + for item in items: + seq, act_log_prob, value, ret, adv, att_mask, act_mask = ( + item.sequences, + item.action_log_probs, + item.values, + item.returns, + item.advantages, + item.attention_mask, + item.action_mask, + ) + right_pad = (1 - 
act_mask.long()).sum() + right_pad = None if right_pad == 0 else -right_pad + + # left_pad for seq and att_mask + left_pad = att_mask.long().argmax() + ( + item.sequences, + item.action_log_probs, + item.values, + item.returns, + item.advantages, + item.attention_mask, + item.action_mask, + ) = ( + seq[left_pad:right_pad], + act_log_prob[:right_pad], + value[:right_pad] if item.values is not None else None, + ret[:right_pad], + adv[:right_pad], + att_mask[left_pad:right_pad], + act_mask[:right_pad], + ) + return items + + +class NaiveReplayBuffer(ABC): + """Naive replay buffer class. It stores experience. + + Args: + sample_batch_size (int): Batch size when sampling. + limit (int, optional): Limit of number of experience samples. A number <= 0 means unlimited. Defaults to 0. + cpu_offload (bool, optional): Whether to offload experience to cpu when sampling. Defaults to True. + """ + + def __init__( + self, sample_batch_size: int, limit: int = 0, cpu_offload: bool = True, packing_samples: bool = False + ) -> None: + super().__init__() + self.sample_batch_size = sample_batch_size + # limit <= 0 means unlimited + self.limit = limit + self.cpu_offload = cpu_offload + self.packing_samples = packing_samples + self.target_device = torch.device(f"cuda:{torch.cuda.current_device()}") + self.items: List[BufferItem] = [] + + @torch.no_grad() + def append(self, experience: Experience) -> None: + if self.cpu_offload: + experience.to_device(torch.device("cpu")) + items = split_experience_batch(experience) + # the packed samples comes with no padding + if not self.packing_samples: + items = remove_padding_in_sequences(items) + self.items.extend(items) + if self.limit > 0: + samples_to_remove = len(self.items) - self.limit + if samples_to_remove > 0: + self.items = self.items[samples_to_remove:] + + def clear(self) -> None: + self.items.clear() + + @torch.no_grad() + def sample(self) -> Experience: + items = random.sample(self.items, self.sample_batch_size) + experience = make_experience_batch(items, self.packing_samples) + if self.cpu_offload: + experience.to_device(self.target_device) + return experience + + def __len__(self) -> int: + return len(self.items) + + def __getitem__(self, idx: int) -> BufferItem: + return self.items[idx] + + def collate_fn(self, batch) -> Experience: + experience = make_experience_batch(batch, self.packing_samples) + return experience + + def normalize(self, attribute: str, strategy) -> None: + assert attribute == "advantages" + items = [] + action_masks = [] + for item in self: + items.append(getattr(item, attribute)) + action_masks.append(item.action_mask) + + items_vector = torch.cat(items).float().flatten() + + if action_masks[0] is None: + # packing samples has no action mask + action_masks_vector = 1 + num_actions = items_vector.numel() + else: + action_masks_vector = torch.cat(action_masks).flatten() + num_actions = action_masks_vector.sum() + + # for DP + # mean + sum_and_count = torch.tensor([items_vector.sum(), num_actions], device=items_vector.device) + all_sum, all_count = strategy.all_reduce(sum_and_count, "sum") + mean = all_sum / all_count + # std + std = ((items_vector - mean).pow(2) * action_masks_vector).sum() + all_std = strategy.all_reduce(std, "sum") + rstd = (all_std / all_count).clamp(min=1e-8).rsqrt() + + for i, item in enumerate(self): + setattr(item, attribute, (items[i] - mean) * rstd) diff --git a/PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/openrlhf/trainer/prm_trainer.py 
b/PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/openrlhf/trainer/prm_trainer.py new file mode 100644 index 0000000000..ce6b3090e0 --- /dev/null +++ b/PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/openrlhf/trainer/prm_trainer.py @@ -0,0 +1,249 @@ +from abc import ABC + +import torch +from torch.optim import Optimizer +from tqdm import tqdm + +from openrlhf.models import PRMLoss +from openrlhf.utils.distributed_sampler import DistributedSampler +from openrlhf.utils.utils import convert_token_to_id + + +class ProcessRewardModelTrainer(ABC): + """ + Trainer for training a process reward model. + + Args: + model (torch.nn.Module): The model to be trained. + strategy (Strategy): The training strategy to apply. + optim (Optimizer): The optimizer to use during training. + train_dataloader (DataLoader): The dataloader for the training dataset. + eval_dataloader (DataLoader): The dataloader for the evaluation dataset. + scheduler (Scheduler): The learning rate scheduler for dynamic adjustments during training. + tokenizer (Tokenizer): The tokenizer for processing input text data. + max_norm (float, defaults to 0.5): Maximum gradient norm for gradient clipping. + max_epochs (int, defaults to 2): Maximum number of training epochs. + """ + + def __init__( + self, + model, + strategy, + optim: Optimizer, + train_dataloader, + eval_dataloader, + scheduler, + max_norm: float = 1, + batch_size: int = 1, + max_epochs: int = 2, + tokenizer=None, + ) -> None: + super().__init__() + self.strategy = strategy + self.epochs = max_epochs + self.batch_size = batch_size + self.max_norm = max_norm + self.train_dataloader = train_dataloader + self.eval_dataloader = eval_dataloader + self.scheduler = scheduler + self.model = model + self.tokenizer = tokenizer + self.optimizer = optim + self.args = strategy.args + + # set placeholder token + self.placeholder_token_id = convert_token_to_id(strategy.args.placeholder_token, self.tokenizer) + self.reward_token_ids = self.args.reward_tokens + if self.reward_token_ids is not None: + self.reward_token_ids = [convert_token_to_id(token, self.tokenizer) for token in self.reward_token_ids] + + self.ignore_index = -100 + self.loss_fn = PRMLoss(self.placeholder_token_id, self.reward_token_ids) + + # Mixtral 8*7b + self.aux_loss = self.args.aux_loss_coef > 1e-8 + + # packing samples + self.packing_samples = strategy.args.packing_samples + + # wandb setting + self._wandb = None + if self.strategy.args.use_wandb and self.strategy.is_rank_0(): + import wandb + + self._wandb = wandb + if not wandb.api.api_key: + wandb.login(key=strategy.args.use_wandb) + wandb.init( + entity=strategy.args.wandb_org, + project=strategy.args.wandb_project, + group=strategy.args.wandb_group, + name=strategy.args.wandb_run_name, + config=strategy.args.__dict__, + reinit=True, + ) + + wandb.define_metric("train/global_step") + wandb.define_metric("train/*", step_metric="train/global_step", step_sync=True) + wandb.define_metric("eval/global_step") + wandb.define_metric("eval/*", step_metric="eval/global_step", step_sync=True) + + def fit(self, args, consumed_samples=0, num_update_steps_per_epoch=None): + # get eval and save steps + if args.eval_steps == -1: + args.eval_steps = num_update_steps_per_epoch # Evaluate once per epoch + if args.save_steps == -1: + args.save_steps = float("inf") # do not save ckpt + + # Restore step and start_epoch + step = consumed_samples // args.train_batch_size * self.strategy.accumulated_gradient + 1 + start_epoch = consumed_samples // args.train_batch_size // 
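# Worked example of the consumed_samples bookkeeping around this point, with
# made-up sizes:
#   train_batch_size = 128, accumulated_gradient = 8,
#   num_update_steps_per_epoch = 100, consumed_samples = 32000
#   optimizer updates done = 32000 // 128 = 250
#   step        = 250 * 8 + 1 = 2001   (next micro-batch step; each optimizer
#                                        update spans 8 accumulated micro-steps)
#   start_epoch = 250 // 100 = 2       (two full epochs already finished)
#   consumed_samples % (100 * 128) = 32000 % 12800 = 6400 samples already seen
#   inside epoch 2, which the DistributedSampler skips via set_epoch(...).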
num_update_steps_per_epoch + consumed_samples = consumed_samples % (num_update_steps_per_epoch * args.train_batch_size) + + epoch_bar = tqdm( + range(start_epoch, self.epochs), + desc="Train epoch", + disable=not self.strategy.is_rank_0(), + ) + loss_sum = 0 + acc_sum = 0 + for epoch in range(start_epoch, self.epochs): + if isinstance(self.train_dataloader.sampler, DistributedSampler): + self.train_dataloader.sampler.set_epoch( + epoch, consumed_samples=0 if epoch > start_epoch else consumed_samples + ) + + step_bar = tqdm( + range(self.train_dataloader.__len__()), + desc="Train step of epoch %d" % epoch, + disable=not self.strategy.is_rank_0(), + ) + + # train + self.model.train() + for data in self.train_dataloader: + if not self.packing_samples: + inputs, attention_masks, labels = data + inputs = inputs.to(torch.cuda.current_device()) + attention_mask = attention_masks.to(torch.cuda.current_device()) + labels = labels.to(torch.cuda.current_device()) + packed_seq_lens = None + else: + inputs, attention_masks, packed_seq_lens, labels = data + inputs = inputs.to(torch.cuda.current_device()).squeeze(1) + attention_mask = attention_masks.to(torch.cuda.current_device()).squeeze(1) + labels = labels.to(torch.cuda.current_device()).squeeze(1) + + output = self.model( + inputs, + attention_mask=attention_mask, + return_output=True, + ring_attn_group=self.strategy.ring_attn_group, + packed_seq_lens=packed_seq_lens, + ) + + # mixtral + if self.aux_loss: + aux_loss = output.aux_loss + else: + aux_loss = 0 + + prm_loss, acc = self.loss_fn(inputs, output.logits, labels, return_acc=True) + loss = prm_loss + aux_loss * self.args.aux_loss_coef + self.strategy.backward(loss, self.model, self.optimizer) + self.strategy.optimizer_step(self.optimizer, self.model, self.scheduler) + + loss_sum += loss.item() + acc_sum += acc.item() + logs_dict = { + "prm_loss": prm_loss.item(), + "acc": acc.item(), + "lr": self.scheduler.get_last_lr()[0], + } + if self.aux_loss: + logs_dict["aux_loss"] = aux_loss.item() + # step bar + logs_dict = self.strategy.all_reduce(logs_dict) + step_bar.set_postfix(logs_dict) + step_bar.update() + + # logs/checkpoints/evaluation + if step % self.strategy.accumulated_gradient == 0: + logs_dict["loss_mean"] = loss_sum / self.strategy.accumulated_gradient + logs_dict["acc_mean"] = acc_sum / self.strategy.accumulated_gradient + loss_sum = 0 + acc_sum = 0 + global_step = step // self.strategy.accumulated_gradient + client_states = {"consumed_samples": global_step * args.train_batch_size} + self.save_logs_and_checkpoints(args, global_step, step_bar, logs_dict, client_states) + + step += 1 + + epoch_bar.update() + + # logs/checkpoints/evaluation + def save_logs_and_checkpoints(self, args, global_step, step_bar, logs_dict={}, client_states={}): + if global_step % args.logging_steps == 0: + # wandb + if self._wandb is not None and self.strategy.is_rank_0(): + logs = {"train/%s" % k: v for k, v in {**logs_dict, "global_step": global_step}.items()} + self._wandb.log(logs) + + # eval + if global_step % args.eval_steps == 0: + self.evaluate(self.eval_dataloader, global_step) + # save ckpt + # TODO: save best model on dev, use loss/perplexity on whole dev dataset as metric + if global_step % args.save_steps == 0: + tag = f"global_step{global_step}" + self.strategy.save_ckpt( + self.model.model, args.ckpt_path, tag, args.max_ckpt_num, args.max_ckpt_mem, client_states + ) + + def evaluate(self, eval_dataloader, steps=0): + times = 0 + self.model.eval() + with torch.no_grad(): + loss_sum = 0 + acc_sum = 
0 + step_bar = tqdm( + range(eval_dataloader.__len__()), + desc="Eval stage of steps %d" % steps, + disable=not self.strategy.is_rank_0(), + ) + + for data in eval_dataloader: + if not self.packing_samples: + inputs, attention_masks, labels = data + inputs = inputs.to(torch.cuda.current_device()) + attention_mask = attention_masks.to(torch.cuda.current_device()) + labels = labels.to(torch.cuda.current_device()) + packed_seq_lens = None + else: + inputs, attention_masks, packed_seq_lens, labels = data + inputs = inputs.to(torch.cuda.current_device()).squeeze(1) + attention_mask = attention_masks.to(torch.cuda.current_device()).squeeze(1) + labels = labels.to(torch.cuda.current_device()).squeeze(1) + + output = self.model( + inputs, + attention_mask=attention_mask, + return_output=True, + ring_attn_group=self.strategy.ring_attn_group, + packed_seq_lens=packed_seq_lens, + ) + + loss, acc = self.loss_fn(inputs, output.logits, labels, return_acc=True) + + times += 1 + loss_sum += loss.item() + acc_sum += acc.item() + bar_dict = {"eval prm_loss": loss_sum / times, "eval acc": acc_sum / times} + step_bar.update() + logs = self.strategy.all_reduce(bar_dict) + step_bar.set_postfix(logs) + + if self._wandb is not None and self.strategy.is_rank_0(): + logs = {"eval/%s" % k: v for k, v in {**logs, "global_step": steps}.items()} + self._wandb.log(logs) + self.model.train() # reset model state diff --git a/PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/openrlhf/trainer/ray/__init__.py b/PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/openrlhf/trainer/ray/__init__.py new file mode 100644 index 0000000000..4b1e6842ec --- /dev/null +++ b/PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/openrlhf/trainer/ray/__init__.py @@ -0,0 +1,14 @@ +from .launcher import DistributedTorchRayActor, PPORayActorGroup, ReferenceModelRayActor, RewardModelRayActor +from .ppo_actor import ActorModelRayActor +from .ppo_critic import CriticModelRayActor +from .vllm_engine import create_vllm_engines + +__all__ = [ + "DistributedTorchRayActor", + "PPORayActorGroup", + "ReferenceModelRayActor", + "RewardModelRayActor", + "ActorModelRayActor", + "CriticModelRayActor", + "create_vllm_engines", +] diff --git a/PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/openrlhf/trainer/ray/launcher.py b/PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/openrlhf/trainer/ray/launcher.py new file mode 100644 index 0000000000..492ec9383f --- /dev/null +++ b/PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/openrlhf/trainer/ray/launcher.py @@ -0,0 +1,316 @@ +import logging +import os +import socket +from typing import Callable, Dict, List, Optional, Type + +import ray +import torch +from ray.util.placement_group import PlacementGroup, placement_group +from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy + +from openrlhf.models import Actor, get_llm_for_sequence_regression +from openrlhf.utils.deepspeed import DeepspeedStrategy + +from openrlhf.trainer.ray.utils import ray_noset_visible_devices + +class DistributedTorchRayActor: + def __init__(self, world_size, rank, master_addr, master_port): + logging.basicConfig( + format="%(asctime)s %(levelname)-8s %(message)s", + level=logging.INFO, + datefmt="%Y-%m-%d %H:%M:%S", + ) + self._world_size = world_size + self._rank = rank + self._master_addr = master_addr if master_addr else self._get_current_node_ip() + self._master_port = master_port if master_port else self._get_free_port() + os.environ["MASTER_ADDR"] = self._master_addr + os.environ["MASTER_PORT"] = str(self._master_port) + 
os.environ["WORLD_SIZE"] = str(self._world_size) + os.environ["RANK"] = str(self._rank) + # NOTE: Ray will automatically set the *_VISIBLE_DEVICES + # environment variable for each actor, unless + # RAY_EXPERIMENTAL_NOSET_*_VISIBLE_DEVICES is set, so + # set local rank to 0 when the flag is not applicable. + os.environ["LOCAL_RANK"] = str(ray.get_gpu_ids()[0]) if ray_noset_visible_devices() else "0" + + @staticmethod + def _get_current_node_ip(): + address = ray._private.services.get_node_ip_address() + # strip ipv6 address + return address.strip("[]") + + @staticmethod + def _get_free_port(): + with socket.socket() as sock: + sock.bind(("", 0)) + return sock.getsockname()[1] + + def get_master_addr_port(self): + return self._master_addr, self._master_port + + +class BasePPORole(DistributedTorchRayActor): + def _setup_distributed(self, strategy: DeepspeedStrategy): + # configure strategy + self.strategy = strategy + strategy.setup_distributed() + + def init_model_from_pretrained(self, *args, **kwargs): + raise NotImplementedError() + + +@ray.remote(num_gpus=1) +class ReferenceModelRayActor(BasePPORole): + def init_model_from_pretrained(self, strategy: DeepspeedStrategy, pretrain): + self._setup_distributed(strategy) + model = Actor( + pretrain, + use_flash_attention_2=strategy.args.flash_attn, + bf16=strategy.args.bf16, + load_in_4bit=strategy.args.load_in_4bit, + ds_config=strategy.get_ds_eval_config(offload=strategy.args.ref_reward_offload), + packing_samples=strategy.args.packing_samples, + ) + strategy.print(model) + + if strategy.args.ref_reward_offload: + model._offload = True + + self.model = self.strategy.prepare(model, is_rlhf=True) + self.model.eval() + + def forward( + self, + sequences: torch.LongTensor, + num_actions: int = None, + attention_mask: Optional[torch.Tensor] = None, + return_output=False, + packed_seq_lens: Optional[list[int]] = None, + ) -> torch.Tensor: + device = torch.cuda.current_device() + with torch.no_grad(): + log_probs = self.model( + sequences.to(device), + num_actions, + attention_mask.to(device), + return_output=return_output, + packed_seq_lens=packed_seq_lens, + ) + return log_probs.to("cpu") + + def empty_cache(self) -> None: + torch.cuda.empty_cache() + + +@ray.remote(num_gpus=1) +class RewardModelRayActor(BasePPORole): + def init_model_from_pretrained(self, strategy: DeepspeedStrategy, pretrain): + self._setup_distributed(strategy) + model = get_llm_for_sequence_regression( + pretrain, + "reward", + normalize_reward=strategy.args.normalize_reward, + use_flash_attention_2=strategy.args.flash_attn, + bf16=strategy.args.bf16, + load_in_4bit=strategy.args.load_in_4bit, + ds_config=strategy.get_ds_eval_config(offload=strategy.args.ref_reward_offload), + value_head_prefix=strategy.args.value_head_prefix, + packing_samples=strategy.args.packing_samples, + ) + strategy.print(model) + strategy.print("reward normalization status: {}".format(strategy.args.normalize_reward)) + strategy.print("mean: {}, std {}".format(model.mean, model.std)) + + if strategy.args.ref_reward_offload: + model._offload = True + + self.model = self.strategy.prepare(model, is_rlhf=True) + self.model.eval() + + def forward( + self, sequences: torch.LongTensor, attention_mask: Optional[torch.Tensor] = None, packed_seq_lens=None + ) -> torch.Tensor: + device = torch.cuda.current_device() + with torch.no_grad(): + reward = self.model(sequences.to(device), attention_mask.to(device), packed_seq_lens=packed_seq_lens) + return reward.to("cpu") + + def empty_cache(self) -> None: + 
torch.cuda.empty_cache() + + +class PPORayActorGroup: + """ + A group of ray actors + Functions start with 'async' should return list of object refs + + Args: + num_nodes (int): Number of nodes for this actor group. + num_gpus_per_node (int): Number of gpus for this actor group. + ray_actor_type (Type[BasePPORole]): PPO model type that this actor group serve on. + pg (PlacementGroup, optional): Placement group to schedule actor on. + If none, create new placement group automatically. Defaults to None. + num_gpus_per_actor (float, optional): Number of gpus allocated for each actor. + If < 1.0, multiple models can share same gpu. Defaults to 1. + """ + + def __init__( + self, + num_nodes, + num_gpus_per_node, + ray_actor_type: Type[BasePPORole], + pg: PlacementGroup = None, + num_gpus_per_actor=1, + resources: Dict[str, float] = None, + num_resources_per_node: int = None, + ) -> None: + self._num_nodes = num_nodes + self._num_gpus_per_node = num_gpus_per_node + self.ray_actor_type = ray_actor_type + + # custom resources, see https://docs.ray.io/en/latest/ray-core/scheduling/resources.html + self._resources = resources + self._num_resources_per_node = num_resources_per_node + + self._initiate_actors(pg, num_gpus_per_actor) + + def _initiate_actors(self, pg, num_gpus_per_actor): + world_size = self._num_nodes * self._num_gpus_per_node + + # Use placement group to lock resources for models of same type + if self._num_gpus_per_node > 1 and pg is None: + bundles = [ + {"GPU": self._num_gpus_per_node, "CPU": self._num_gpus_per_node} for _ in range(self._num_nodes) + ] + if self._resources: + resources_name = list(self._resources.keys())[0] + for i in range(len(bundles)): + bundles[i][resources_name] = self._num_resources_per_node + + pg = placement_group(bundles, strategy="STRICT_SPREAD") + ray.get(pg.ready()) + if pg: + master_actor = self.ray_actor_type.options( + num_cpus=num_gpus_per_actor, + num_gpus=num_gpus_per_actor, + resources=self._resources, + scheduling_strategy=PlacementGroupSchedulingStrategy( + placement_group=pg, placement_group_bundle_index=0 + ), + ).remote(world_size, 0, None, None) + else: + master_actor = self.ray_actor_type.options( + num_cpus=num_gpus_per_actor, + num_gpus=num_gpus_per_actor, + resources=self._resources, + ).remote(world_size, 0, None, None) + self._actor_handlers = [master_actor] + + # Create worker actors + if world_size > 1: + master_addr, master_port = ray.get(master_actor.get_master_addr_port.remote()) + for rank in range(1, world_size): + if pg: + worker_actor = self.ray_actor_type.options( + num_cpus=num_gpus_per_actor, + num_gpus=num_gpus_per_actor, + resources=self._resources, + scheduling_strategy=PlacementGroupSchedulingStrategy( + placement_group=pg, + placement_group_bundle_index=rank // self._num_gpus_per_node, + ), + ).remote(world_size, rank, master_addr, master_port) + else: + worker_actor = self.ray_actor_type.options( + num_cpus=num_gpus_per_actor, + num_gpus=num_gpus_per_actor, + resources=self._resources, + ).remote(world_size, rank, master_addr, master_port) + self._actor_handlers.append(worker_actor) + + def async_init_model_from_pretrained( + self, + *args, + **kwargs, + ): + """Init model from pretrained checkpoint. + + Returns: + List: list of remote object refs. 
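# Bundle/rank layout produced by _initiate_actors above, with made-up sizes
# (num_nodes = 2, num_gpus_per_node = 4, so world_size = 8):
#   bundles = [{"GPU": 4, "CPU": 4}, {"GPU": 4, "CPU": 4}]   # STRICT_SPREAD: one per node
num_nodes, num_gpus_per_node = 2, 4
world_size = num_nodes * num_gpus_per_node
bundle_index = [rank // num_gpus_per_node for rank in range(world_size)]
# -> [0, 0, 0, 0, 1, 1, 1, 1]: each worker is pinned to the bundle (node) that
# owns its GPUs, and num_gpus_per_actor < 1.0 lets several roles (actor, critic,
# reference, reward) share the same physical GPUs.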
+ """ + return [actor.init_model_from_pretrained.remote(*args, **kwargs) for actor in self._actor_handlers] + + def async_fit_actor_model( + self, + critic_model_group: "PPORayActorGroup", + initial_model_group: "PPORayActorGroup", + reward_model_groups: List["PPORayActorGroup"], + remote_rm_urls: List[str] = None, + reward_fn: Callable[[List[torch.Tensor]], torch.Tensor] = None, + vllm_engines: List = None, + ): + """Train actor model. + + Args: + critic_model_group (PPORayActorGroup): critic model group. + initial_model_group (PPORayActorGroup): reference model group. + reward_model_groups (PPORayActorGroup): reward model groups. + remote_rm_urls: remote RM APIs. + reward_fn: reward calculate function, must be specified if using multiple reward models. + vllm_engines: vllm engines for text generation, if not specified, generate text by actor model directly. + + Returns: + List: list of remote object refs. + """ + assert ( + (remote_rm_urls and len(remote_rm_urls) == 1) + or (reward_model_groups and len(reward_model_groups) == 1) + or reward_fn is not None + ), "reward_fn must be specified if using multiple reward models" + + critic_actors = critic_model_group._actor_handlers if critic_model_group else None + initial_actors = initial_model_group._actor_handlers + + refs = [] + # TODO(wuxibin): actor model choose critic/reward/initial model in a + # round robin fashion, implement more efficient dispatching strategy. + for i, actor in enumerate(self._actor_handlers): + critic_actor = critic_actors[i % len(critic_actors)] if critic_actors else None + initial_actor = initial_actors[i % len(initial_actors)] + + reward_actors = [] + if not remote_rm_urls: + for reward_model_group in reward_model_groups: + actors = reward_model_group._actor_handlers + reward_actors.append(actors[i % len(actors)]) + + refs.append( + actor.fit.remote( + critic_model=critic_actor, + initial_model=initial_actor, + reward_model=reward_actors, + remote_rm_url=remote_rm_urls, + reward_fn=reward_fn, + vllm_engines=vllm_engines, + # whether this actor should triger corresponding critic model training + critic_train_remote=(i < len(critic_actors)) if critic_actor else None, + ) + ) + + return refs + + def async_save_model(self): + """Save actor model on rank 0. + + Returns: + List: list of remote object refs. 
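# Round-robin dispatch used by async_fit_actor_model above, with made-up counts
# (8 actor ranks, 4 critic ranks, 2 reward ranks in a single reward group):
num_actors, num_critics, num_rewards = 8, 4, 2
assignment = [(i % num_critics, i % num_rewards, i < num_critics) for i in range(num_actors)]
# -> actors 0..7 query critics [0, 1, 2, 3, 0, 1, 2, 3] and rewards
#    [0, 1, 0, 1, 0, 1, 0, 1]; only actors 0..3 get critic_train_remote=True,
#    so each remote critic is driven by exactly one actor rank.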
+ """ + return [actor.save_model.remote() for actor in self._actor_handlers] + + def async_run_method(self, method_name, *args, **kwargs): + refs = [] + for actor in self._actor_handlers: + method = getattr(actor, method_name) + refs.append(method.remote(*args, **kwargs)) + return refs diff --git a/PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/openrlhf/trainer/ray/ppo_actor.py b/PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/openrlhf/trainer/ray/ppo_actor.py new file mode 100644 index 0000000000..9661b0edb0 --- /dev/null +++ b/PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/openrlhf/trainer/ray/ppo_actor.py @@ -0,0 +1,448 @@ +import itertools +import math +import os +import socket +from typing import Callable, Dict, List + +import deepspeed +import ray +import torch +import torch.distributed +from transformers.trainer import get_scheduler + +from openrlhf.datasets import PromptDataset, SFTDataset +from openrlhf.models import Actor +from openrlhf.trainer import PPOTrainer +from openrlhf.trainer.ppo_utils import Experience, RemoteExperienceMaker +from openrlhf.utils import blending_datasets, get_tokenizer +from openrlhf.utils.deepspeed import DeepspeedStrategy +from openrlhf.utils.distributed_util import init_process_group + +from .launcher import BasePPORole + + +class ActorPPOTrainer(PPOTrainer): + def __init__( + self, + *args, + vllm_engines: List = None, + remote_rm_url: List[str] = None, + critic_train_remote: bool = False, + **kwargs, + ): + """PPOTrainer for ray. + + Args: + vllm_engines (List, optional): vllm engines for text generation, if not specified, generate text by actor model directly. Defaults to None. + critic_train_remote (bool, optional): whether this actor should triger corresponding critic model training. Defaults to False. + """ + super().__init__(*args, **kwargs) + self.remote_rm_url = remote_rm_url + self.vllm_engines = vllm_engines + self.critic_train_remote = critic_train_remote + + self.experience_maker = RemoteExperienceMaker( + self.actor, + self.critic, + self.reward_model, + self.initial_model, + self.tokenizer, + self.prompt_max_len, + self.kl_ctl, + self.strategy, + self.remote_rm_url, + self.reward_fn, + vllm_engines=self.vllm_engines, + packing_samples=self.strategy.args.packing_samples, + ) + + # Create torch group with deepspeed rank 0 and all vllm ranks + # to update vllm engine's weights after each training stage. + # + # Say we have 3 vllm engines and eache of them has 4 GPUs, + # then the torch group is: + # [ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12] + # |ds rank 0 | engine-0 | engine-1 | engine-2 | + # + # For ZeRO-1/2: + # 1. Broadcast parameters from rank 0 to all vllm engines + # For ZeRO-3: + # 1. AllGather paramters to rank 0 + # 2. 
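# Rank layout of the weight-update group built just below, with made-up sizes
# (vllm_num_engines = 3, vllm_tensor_parallel_size = 4):
#   world_size = 3 * 4 + 1 = 13
#   rank 0     = this DeepSpeed training rank (the broadcast source)
#   engine i   = ranks i * 4 + 1 ... i * 4 + 4  (the rank_offset passed below)
# i.e. engine-0 -> ranks 1-4, engine-1 -> ranks 5-8, engine-2 -> ranks 9-12,
# matching the [0 | engine-0 | engine-1 | engine-2] diagram above.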
Broadcast parameters from rank 0 to all vllm engines + if self.vllm_engines is not None and torch.distributed.get_rank() == 0: + master_address = ray._private.services.get_node_ip_address() + with socket.socket() as sock: + sock.bind(("", 0)) + master_port = sock.getsockname()[1] + + vllm_num_engines, vllm_tensor_parallel_size = ( + self.strategy.args.vllm_num_engines, + self.strategy.args.vllm_tensor_parallel_size, + ) + world_size = vllm_num_engines * vllm_tensor_parallel_size + 1 + + backend = getattr(self.strategy.args, "vllm_sync_backend", "nccl") + use_ray = getattr(self.strategy.args, "vllm_sync_with_ray", False) + group_name = "openrlhf" + refs = [ + engine.init_process_group.remote( + master_address, + master_port, + i * vllm_tensor_parallel_size + 1, + world_size, + group_name, + backend=backend, + use_ray=use_ray, + ) + for i, engine in enumerate(self.vllm_engines) + ] + if use_ray: + import ray.util.collective as collective + collective.init_collective_group( + world_size=world_size, + rank=0, + backend=backend, + group_name=group_name + ) + self._model_update_group = group_name + else: + self._model_update_group = init_process_group( + backend=backend, + init_method=f"tcp://{master_address}:{master_port}", + world_size=world_size, + rank=0, + group_name=group_name, + ) + + ray.get(refs) + + torch.distributed.barrier() + + def ppo_train(self, global_steps): + # 1. ensure all experience makers are done + self.experience_maker.flush() + torch.distributed.barrier() + + # 2. trigger remote critic model training + if self.critic_train_remote: + critic_status_ref = self.critic.fit.remote() + + # 3. actor model training + if global_steps > self.freezing_actor_steps: + status = super().ppo_train(global_steps) + + # 4. broadcast weights to vllm engines + if self.vllm_engines is not None: + torch.distributed.barrier() + self._broadcast_to_vllm() + else: + status = {} + + # 5. 
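# The ordering in ppo_train() above overlaps the remote critic update with the
# local actor update: critic.fit.remote() is launched first and only awaited in
# step 5, after the actor has trained and pushed new weights to vLLM. A
# stripped-down sketch of that futures pattern (the actor name and timings here
# are made up, not the real OpenRLHF classes):
import time

import ray

@ray.remote
class ToyCritic:
    def fit(self):
        time.sleep(1.0)              # stands in for the critic's PPO epoch
        return {"critic_loss": 0.1}

ray.init()
critic = ToyCritic.remote()
future = critic.fit.remote()         # step 2: kick off remote critic training
time.sleep(1.0)                      # steps 3-4: actor trains, broadcasts weights
status = ray.get(future)             # step 5: block only when the result is needed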
wait remote critic model training done + if self.critic_train_remote: + status.update(ray.get(critic_status_ref)) + torch.distributed.barrier() + + return status + + def training_step(self, experience: Experience, global_steps) -> Dict[str, float]: + return self.training_step_actor(experience) + + def _broadcast_to_vllm(self): + use_prefix_cache = getattr(self.strategy.args, "enable_prefix_caching", False) + cache_reset_refs = [] + if use_prefix_cache and torch.distributed.get_rank() == 0: + # clear prefix cache + for engine in self.vllm_engines: + cache_reset_refs.append(engine.reset_prefix_cache.remote()) + # avoid OOM + torch.cuda.empty_cache() + use_ray = getattr(self.strategy.args, "vllm_sync_with_ray", False) + model = self.actor.model.module + count, num_params = 0, len(list(model.named_parameters())) + for name, param in model.named_parameters(): + count += 1 # empty_cache at last param + + # Fire all vllm engines for broadcast + if torch.distributed.get_rank() == 0: + shape = param.shape if self.strategy.args.zero_stage != 3 else param.ds_shape + refs = [ + engine.update_weight.remote(name, dtype=param.dtype, shape=shape, empty_cache=count == num_params) + for engine in self.vllm_engines + ] + + # For ZeRO-3, allgather sharded parameter and broadcast to all vllm engines by rank 0 + with deepspeed.zero.GatheredParameters([param], enabled=self.strategy.args.zero_stage == 3): + if torch.distributed.get_rank() == 0: + if use_ray: + import ray.util.collective as collective + collective.broadcast(param.data, 0, group_name=self._model_update_group) + else: + torch.distributed.broadcast(param.data, 0, group=self._model_update_group) + ray.get(refs) + if cache_reset_refs: + ray.get(cache_reset_refs) + torch.distributed.barrier() + + def _save_checkpoint(self, args, tag, client_states): + # call remote critic + if not self.disable_ds_ckpt: + if self.critic_train_remote: + ref = self.critic.save_checkpoint.remote(tag) + self.strategy.save_ckpt( + self.actor.model, + os.path.join(args.ckpt_path, "_actor"), + tag, + args.max_ckpt_num, + args.max_ckpt_mem, + client_states, + ) + if self.save_hf_ckpt: + save_path = os.path.join(args.ckpt_path, f"{tag}_hf") + self.strategy.save_model( + self.ema_model if args.enable_ema else self.actor, + self.tokenizer, + save_path, + ) + # wait + if not self.disable_ds_ckpt: + if self.critic_train_remote: + ray.get(ref) + torch.distributed.barrier() + + +@ray.remote(num_gpus=1) +class ActorModelRayActor(BasePPORole): + def init_model_from_pretrained(self, strategy: DeepspeedStrategy, pretrain): + args = strategy.args + + if getattr(args, "vllm_num_engines", 0) > 0: + # To prevent hanging during NCCL synchronization of weights between DeepSpeed and vLLM. 
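# Per-parameter protocol implemented by _broadcast_to_vllm() above:
#   1. rank 0 tells every engine which tensor comes next via
#      engine.update_weight.remote(name, dtype, shape), so the vLLM side can
#      allocate an empty buffer of the right shape;
#   2. for ZeRO-3 the sharded parameter is materialized on rank 0 inside
#      deepspeed.zero.GatheredParameters before the broadcast; for ZeRO-1/2 the
#      full parameter already lives on every rank, so that context is a no-op;
#   3. rank 0 broadcasts param.data over the weight-update group (NCCL or
#      ray.util.collective) and the WorkerWrap receivers load it with
#      model_runner.model.load_weights.
# Streaming one tensor at a time keeps peak memory near a single full parameter
# instead of a whole unsharded copy of the model.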
+ # see https://github.com/vllm-project/vllm/blob/c6b0a7d3ba03ca414be1174e9bd86a97191b7090/vllm/worker/worker_base.py#L445 + if getattr(args, "vllm_sync_backend", "nccl") == "nccl": + os.environ["NCCL_CUMEM_ENABLE"] = "0" + + self._setup_distributed(strategy) + + actor = Actor( + pretrain, + use_flash_attention_2=strategy.args.flash_attn, + bf16=strategy.args.bf16, + load_in_4bit=strategy.args.load_in_4bit, + lora_rank=strategy.args.lora_rank, + lora_alpha=strategy.args.lora_alpha, + target_modules=strategy.args.target_modules, + lora_dropout=strategy.args.lora_dropout, + ds_config=strategy.get_ds_train_config(is_actor=True), + packing_samples=strategy.args.packing_samples, + ) + strategy.print(actor) + + # configure tokenizer + self.tokenizer = get_tokenizer( + pretrain, actor.model, "left", strategy, use_fast=not strategy.args.disable_fast_tokenizer + ) + + if args.enable_ema: + ema_model = Actor( + pretrain, + use_flash_attention_2=strategy.args.flash_attn, + bf16=strategy.args.bf16, + load_in_4bit=strategy.args.load_in_4bit, + ds_config=strategy.get_ds_eval_config(offload=True), + packing_samples=strategy.args.packing_samples, + ) + else: + ema_model = None + + # configure optimizer + actor_optim = strategy.create_optimizer( + actor, lr=args.actor_learning_rate, betas=strategy.args.adam_betas, weight_decay=args.l2 + ) + + # prepare_datasets + self.prepare_datasets() + + # configure scheduler + self.num_update_steps_per_episodes = ( + len(self.prompts_dataset) * args.n_samples_per_prompt // args.train_batch_size * args.max_epochs + ) + max_steps = math.ceil(args.num_episodes * self.num_update_steps_per_episodes) + self._max_steps = max_steps + + actor_scheduler = get_scheduler( + "cosine_with_min_lr", + actor_optim, + num_warmup_steps=math.ceil(max_steps * args.lr_warmup_ratio), + num_training_steps=max_steps, + scheduler_specific_kwargs={"min_lr": args.actor_learning_rate * 0.1}, + ) + + if args.gradient_checkpointing: + actor.gradient_checkpointing_enable( + gradient_checkpointing_kwargs={"use_reentrant": args.gradient_checkpointing_use_reentrant} + ) + + # prepare models/optimizers... 
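# Scheduler-length arithmetic used above, with made-up sizes:
#   len(prompts_dataset) = 10240, n_samples_per_prompt = 4,
#   train_batch_size = 512, max_epochs = 1, num_episodes = 1
#   num_update_steps_per_episodes = 10240 * 4 // 512 * 1 = 80
#   max_steps = ceil(1 * 80) = 80 optimizer steps
# "cosine_with_min_lr" then warms up for ceil(80 * lr_warmup_ratio) steps and
# decays to min_lr = 0.1 * actor_learning_rate by step 80.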
+ self.actor, self.actor_optim, self.actor_scheduler = strategy.prepare( + (actor, actor_optim, actor_scheduler), + is_rlhf=True, + ) + + if ema_model: + ema_model._offload = True + self.ema_model = strategy.prepare(ema_model, is_rlhf=True) + else: + self.ema_model = None + + # load checkpoint + self.consumed_samples = 0 + ckpt_path = os.path.join(args.ckpt_path, "_actor") + if args.load_checkpoint and os.path.exists(ckpt_path): + _, states = strategy.load_ckpt(self.actor.model, ckpt_path) + self.consumed_samples = states["consumed_samples"] + strategy.print(f"Loaded the checkpoint: {ckpt_path}, consumed_samples: {self.consumed_samples}") + + def prepare_datasets(self): + strategy = self.strategy + args = self.strategy.args + + # prepare datasets + prompts_data = blending_datasets( + args.prompt_data, + args.prompt_data_probs, + strategy, + args.seed, + max_count=args.max_samples, + return_eval=False, + train_split=args.prompt_split, + ) + prompts_data = prompts_data.select(range(min(args.max_samples, len(prompts_data)))) + self.prompts_dataset = PromptDataset( + prompts_data, self.tokenizer, strategy, input_template=args.input_template + ) + self.prompts_dataloader = strategy.setup_dataloader( + self.prompts_dataset, args.rollout_batch_size // strategy.world_size, True, True + ) + + if args.pretrain_data: + pretrain_data = blending_datasets( + args.pretrain_data, + args.pretrain_data_probs, + strategy, + args.seed, + return_eval=False, + train_split=args.pretrain_split, + ) + pretrain_max_len = args.max_len if args.max_len else args.prompt_max_len + args.generate_max_len + pretrain_dataset = SFTDataset( + pretrain_data.select( + range( + min( + len(pretrain_data), args.max_epochs * len(self.prompts_dataset) * args.n_samples_per_prompt + ) + ) + ), + self.tokenizer, + pretrain_max_len, + strategy, + pretrain_mode=True, + ) + self.pretrain_dataloader = itertools.cycle( + iter( + strategy.setup_dataloader( + pretrain_dataset, + args.micro_train_batch_size, + True, + True, + pretrain_dataset.collate_fn, + ) + ) + ) + else: + self.pretrain_dataloader = None + + def max_steps(self): + """Return the maximum number of steps.""" + return self._max_steps + + def fit( + self, + critic_model: ray.actor.ActorHandle, + initial_model: ray.actor.ActorHandle, + reward_model: List[ray.actor.ActorHandle], + remote_rm_url: List[str] = None, + reward_fn: Callable[[List[torch.Tensor]], torch.Tensor] = None, + vllm_engines: List[ray.actor.ActorHandle] = None, + critic_train_remote: bool = False, + ): + """Train actor model with prompt datasets.""" + strategy = self.strategy + args = self.strategy.args + + # configure Trainer + trainer = ActorPPOTrainer( + strategy, + self.actor, + critic_model, + reward_model, + initial_model, + ema_model=self.ema_model, + actor_optim=None, + critic_optim=None, + actor_scheduler=self.actor_scheduler, + critic_scheduler=None, + remote_rm_url=remote_rm_url, + reward_fn=reward_fn, + vllm_engines=vllm_engines, + max_epochs=args.max_epochs, + micro_train_batch_size=args.micro_train_batch_size, + micro_rollout_batch_size=args.micro_rollout_batch_size, + gradient_checkpointing=args.gradient_checkpointing, + critic_train_remote=critic_train_remote, + tokenizer=self.tokenizer, + prompt_max_len=args.prompt_max_len, + value_clip=args.value_clip, + eps_clip=args.eps_clip, + gamma=args.gamma, + lambd=args.lambd, + init_kl_coef=args.init_kl_coef, + kl_target=args.kl_target, + ema_beta=0.992, + ptx_coef=args.ptx_coef, + max_norm=args.max_norm, + # fro GPT generation + do_sample=True, + 
max_new_tokens=args.generate_max_len, + max_length=args.max_len, + temperature=args.temperature, + top_p=args.top_p, + pad_token_id=self.tokenizer.pad_token_id, + eos_token_id=self.tokenizer.eos_token_id, + save_hf_ckpt=args.save_hf_ckpt, + disable_ds_ckpt=args.disable_ds_ckpt, + ) + + # broadcast checkpoint + ckpt_path = os.path.join(args.ckpt_path, "_actor") + if args.load_checkpoint and os.path.exists(ckpt_path) and not vllm_engines is None: + torch.distributed.barrier() + trainer._broadcast_to_vllm() + + trainer.fit( + args, + self.prompts_dataloader, + self.pretrain_dataloader, + self.consumed_samples, + self.num_update_steps_per_episodes, + ) + + def save_model(self): + args = self.strategy.args + + # save model checkpoint after fitting on only rank0 + self.strategy.save_model( + self.ema_model if args.enable_ema else self.actor, + self.tokenizer, + args.save_path, + ) diff --git a/PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/openrlhf/trainer/ray/ppo_critic.py b/PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/openrlhf/trainer/ray/ppo_critic.py new file mode 100644 index 0000000000..d4a40b0843 --- /dev/null +++ b/PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/openrlhf/trainer/ray/ppo_critic.py @@ -0,0 +1,197 @@ +import math +import os +from typing import Dict, Optional, Union + +import ray +import torch +from torch.utils.data import DataLoader +from tqdm import tqdm +from transformers.trainer import get_scheduler + +from openrlhf.models import get_llm_for_sequence_regression +from openrlhf.trainer import PPOTrainer +from openrlhf.trainer.ppo_utils import Experience +from openrlhf.utils import get_tokenizer +from openrlhf.utils.deepspeed import DeepspeedStrategy + +from .launcher import BasePPORole + + +class CriticPPOTrainer(PPOTrainer): + def ppo_train(self): + # replay buffer may be empty at first, we should rebuild at each training + dataloader = DataLoader( + self.replay_buffer, + batch_size=self.replay_buffer.sample_batch_size, + shuffle=True, + drop_last=True, + pin_memory=self.dataloader_pin_memory, + collate_fn=self.replay_buffer.collate_fn, + ) + device = torch.cuda.current_device() + + status_list = [] + status_mean = {} + for epoch in range(self.max_epochs): + pbar = tqdm( + dataloader, + desc=f"Train epoch [{epoch + 1}/{self.max_epochs}]", + disable=not self.strategy.is_rank_0(), + ) + for experience in pbar: + experience.to_device(device) + status = self.training_step(experience) + + # for DP + status = self.strategy.all_reduce(status) + + status_list.append(status) + pbar.set_postfix(status) + + if status_list: + status_mean = status_list[0] + for m in status_list[1:]: + for k, v in m.items(): + status_mean[k] += v + for k in status_mean.keys(): + status_mean[k] /= len(status_list) + return status_mean + + def training_step(self, experience: Experience) -> Dict[str, float]: + return self.training_step_critic(experience) + + +@ray.remote(num_gpus=1) +class CriticModelRayActor(BasePPORole): + def init_model_from_pretrained(self, strategy: DeepspeedStrategy, pretrain, max_steps): + args = strategy.args + + self._setup_distributed(strategy) + critic = get_llm_for_sequence_regression( + pretrain, + "critic", + normalize_reward=strategy.args.normalize_reward, + use_flash_attention_2=strategy.args.flash_attn, + bf16=strategy.args.bf16, + load_in_4bit=strategy.args.load_in_4bit, + lora_rank=strategy.args.lora_rank, + lora_alpha=strategy.args.lora_alpha, + target_modules=strategy.args.target_modules, + lora_dropout=strategy.args.lora_dropout, + 
ds_config=strategy.get_ds_train_config(is_actor=False), + value_head_prefix=strategy.args.value_head_prefix, + init_value_head=strategy.args.pretrain == strategy.args.critic_pretrain, + packing_samples=strategy.args.packing_samples, + ) + strategy.print(critic) + strategy.print("reward normalization status: {}".format(strategy.args.normalize_reward)) + strategy.print("mean: {}, std {}".format(critic.mean, critic.std)) + + # configure tokenizer + if strategy.args.save_value_network: + self.tokenizer = get_tokenizer( + pretrain, critic, "left", strategy, use_fast=not strategy.args.disable_fast_tokenizer + ) + + # configure optimizer + critic_optim = strategy.create_optimizer( + critic, lr=args.critic_learning_rate, betas=args.adam_betas, weight_decay=args.l2 + ) + + # configure scheduler + critic_scheduler = get_scheduler( + "cosine_with_min_lr", + critic_optim, + num_warmup_steps=math.ceil(max_steps * args.lr_warmup_ratio), + num_training_steps=max_steps, + scheduler_specific_kwargs={"min_lr": args.critic_learning_rate * 0.1}, + ) + + if args.gradient_checkpointing: + critic.gradient_checkpointing_enable( + gradient_checkpointing_kwargs={"use_reentrant": args.gradient_checkpointing_use_reentrant} + ) + + # prepare models/optimizers... + self.critic, self.critic_optim, self.critic_scheduler = strategy.prepare( + (critic, critic_optim, critic_scheduler), + is_rlhf=True, + ) + + # load checkpoint + if args.load_checkpoint and os.path.exists(os.path.join(args.ckpt_path, "_actor")): + ckpt_path = os.path.join(args.ckpt_path, "_critic") + strategy.load_ckpt(self.critic, ckpt_path) + strategy.print(f"Loaded the checkpoint: {ckpt_path}") + + # configure Trainer + # only use wandb at actor model + strategy.args.use_wandb = False + self.trainer = CriticPPOTrainer( + strategy, + actor=None, + critic=self.critic, + reward_model=None, + initial_model=None, + ema_model=None, + actor_optim=None, + critic_optim=self.critic_optim, + actor_scheduler=None, + critic_scheduler=self.critic_scheduler, + max_epochs=args.max_epochs, + micro_train_batch_size=args.micro_train_batch_size, + micro_rollout_batch_size=args.micro_rollout_batch_size, + gradient_checkpointing=args.gradient_checkpointing, + prompt_max_len=args.prompt_max_len, + value_clip=args.value_clip, + eps_clip=args.eps_clip, + ) + + def forward( + self, + sequences: torch.LongTensor, + num_actions: Optional[Union[int, list[int]]] = None, + attention_mask: Optional[torch.Tensor] = None, + packed_seq_lens=None, + ) -> torch.Tensor: + """Generates critic values.""" + device = torch.cuda.current_device() + self.critic.eval() + with torch.no_grad(): + value = self.critic( + sequences.to(device), num_actions, attention_mask.to(device), packed_seq_lens=packed_seq_lens + ) + self.critic.train() # reset model state + return value.to("cpu") + + def append(self, experience): + """Append experience to replay buffer.""" + self.trainer.replay_buffer.append(experience) + + def fit(self): + """Train critic model with the replay buffer.""" + torch.cuda.empty_cache() + self.critic.train() + status = self.trainer.ppo_train() + self.trainer.replay_buffer.clear() + torch.cuda.empty_cache() + return status + + def empty_cache(self) -> None: + torch.cuda.empty_cache() + + def save_model(self): + args = self.strategy.args + + # save model checkpoint after fitting on only rank0 + self.strategy.save_model( + self.critic, + self.tokenizer, + args.save_path + "_critic", + ) + + def save_checkpoint(self, tag): + args = self.strategy.args + self.strategy.save_ckpt( + self.critic, 
os.path.join(args.ckpt_path, "_critic"), tag, args.max_ckpt_num, args.max_ckpt_mem + ) diff --git a/PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/openrlhf/trainer/ray/utils.py b/PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/openrlhf/trainer/ray/utils.py new file mode 100644 index 0000000000..614ce63bd2 --- /dev/null +++ b/PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/openrlhf/trainer/ray/utils.py @@ -0,0 +1,22 @@ +import os + + +def ray_noset_visible_devices(env_vars=os.environ): + # Refer to + # https://github.com/ray-project/ray/blob/161849364a784442cc659fb9780f1a6adee85fce/python/ray/_private/accelerators/nvidia_gpu.py#L95-L96 + # https://github.com/ray-project/ray/blob/161849364a784442cc659fb9780f1a6adee85fce/python/ray/_private/accelerators/amd_gpu.py#L102-L103 + # https://github.com/ray-project/ray/blob/161849364a784442cc659fb9780f1a6adee85fce/python/ray/_private/accelerators/npu.py#L94-L95 + # https://github.com/ray-project/ray/blob/161849364a784442cc659fb9780f1a6adee85fce/python/ray/_private/accelerators/hpu.py#L116-L117 + # https://github.com/ray-project/ray/blob/161849364a784442cc659fb9780f1a6adee85fce/python/ray/_private/accelerators/neuron.py#L108-L109 + # https://github.com/ray-project/ray/blob/161849364a784442cc659fb9780f1a6adee85fce/python/ray/_private/accelerators/tpu.py#L171-L172 + # https://github.com/ray-project/ray/blob/161849364a784442cc659fb9780f1a6adee85fce/python/ray/_private/accelerators/intel_gpu.py#L97-L98 + NOSET_VISIBLE_DEVICES_ENV_VARS_LIST = [ + "RAY_EXPERIMENTAL_NOSET_CUDA_VISIBLE_DEVICES", + "RAY_EXPERIMENTAL_NOSET_ROCR_VISIBLE_DEVICES", + "RAY_EXPERIMENTAL_NOSET_ASCEND_RT_VISIBLE_DEVICES", + "RAY_EXPERIMENTAL_NOSET_HABANA_VISIBLE_MODULES", + "RAY_EXPERIMENTAL_NOSET_NEURON_RT_VISIBLE_CORES", + "RAY_EXPERIMENTAL_NOSET_TPU_VISIBLE_CHIPS", + "RAY_EXPERIMENTAL_NOSET_ONEAPI_DEVICE_SELECTOR", + ] + return any(env_vars.get(env_var) for env_var in NOSET_VISIBLE_DEVICES_ENV_VARS_LIST) diff --git a/PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/openrlhf/trainer/ray/vllm_engine.py b/PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/openrlhf/trainer/ray/vllm_engine.py new file mode 100644 index 0000000000..889b034242 --- /dev/null +++ b/PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/openrlhf/trainer/ray/vllm_engine.py @@ -0,0 +1,149 @@ +import ray +from ray.util.placement_group import placement_group +from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy + +from openrlhf.trainer.ray.utils import ray_noset_visible_devices + +from openrlhf.utils.logging_utils import init_logger + +logger = init_logger(__name__) + + +@ray.remote +def get_all_env_variables(): + import os + + return os.environ + + +@ray.remote +class LLMRayActor: + def __init__(self, *args, **kwargs): + import vllm + + self.__version__ = vllm.__version__ + assert self.__version__ >= "0.4.2", "OpenRLHF only supports vLLM >= 0.4.2" + + noset_visible_devices = kwargs.pop("noset_visible_devices", False) + self.use_gpu_executor = kwargs["tensor_parallel_size"] == 1 and not noset_visible_devices + + # See https://github.com/vllm-project/vllm/blob/main/vllm/executor/gpu_executor.py + if self.use_gpu_executor: + from openrlhf.trainer.ray.vllm_worker_wrap import WorkerWrap + + vllm.worker.worker.Worker = WorkerWrap + else: + # RayGPUExecutor + # See the patch https://github.com/vllm-project/vllm/commit/479d69fad0538f04cb22bf13e76ff91cfeb8a4e5 + if vllm.__version__ >= "0.4.3": + # https://github.com/vllm-project/vllm/commit/676a99982fe9aabe72fd52a91e08988a653a7359 + 
kwargs["distributed_executor_backend"] = "ray" + else: + kwargs["worker_use_ray"] = True + + if vllm.__version__ > "0.6.4.post1": + # https://github.com/vllm-project/vllm/pull/10555 + kwargs["worker_cls"] = "openrlhf.trainer.ray.vllm_worker_wrap.WorkerWrap" + else: + RayWorkerWrapperPath = vllm.executor.ray_utils + + class RayWorkerWrapper(RayWorkerWrapperPath.RayWorkerWrapper): + def __init__(self, *args, **kwargs) -> None: + kwargs["worker_module_name"] = "openrlhf.trainer.ray.vllm_worker_wrap" + kwargs["worker_class_name"] = "WorkerWrap" + super().__init__(*args, **kwargs) + + RayWorkerWrapperPath.RayWorkerWrapper = RayWorkerWrapper + + self.llm = vllm.LLM(*args, **kwargs) + + def generate(self, *args, **kwargs): + return self.llm.generate(*args, **kwargs) + + def init_process_group(self, master_address, master_port, rank_offset, world_size, group_name, backend, use_ray): + if self.use_gpu_executor: + return self.llm.llm_engine.model_executor.driver_worker.init_process_group( + master_address, master_port, rank_offset, world_size, group_name, backend, use_ray + ) + else: + return self.llm.llm_engine.model_executor._run_workers( + "init_process_group", master_address, master_port, rank_offset, world_size, group_name, backend, use_ray + ) + + def update_weight(self, name, dtype, shape, empty_cache=False): + self.stop_remote_worker_execution_loop() + + if self.use_gpu_executor: + return self.llm.llm_engine.model_executor.driver_worker.update_weight(name, dtype, shape, empty_cache) + else: + return self.llm.llm_engine.model_executor._run_workers("update_weight", name, dtype, shape, empty_cache) + + def reset_prefix_cache(self): + import vllm + if vllm.__version__ < "0.7.0": + # https://github.com/vllm-project/vllm/commit/7206ce4ce112ed117796a59045c968a6d353f691 + logger.warning("Reset prefix cache API is available only from vLLM 0.7.0!") + return + self.llm.llm_engine.reset_prefix_cache() + + def stop_remote_worker_execution_loop(self): + # Fix error for using 2 communication group + # https://github.com/vllm-project/vllm/commit/eb6d3c264d0cd8e44dec16bca7947fbe96415ce9#diff-e1ad69e38e033accddfa5480ec808c4740eb39244d1ef51cc3407e20dde8cfd4 + if self.__version__ > "0.4.2": + self.llm.llm_engine.model_executor.stop_remote_worker_execution_loop() + + +def create_vllm_engines( + num_engines: int, + tensor_parallel_size: int, + pretrain: str, + seed: int, + enable_prefix_caching: bool, + enforce_eager: bool, + max_model_len: int, +): + vllm_engines = [] + # RAY_EXPERIMENTAL_NOSET_*_VISIBLE_DEVICES will always be set in current context, + # So we need to get env variables from ray process to check if it is set. + noset_visible_devices = ray_noset_visible_devices(ray.get(get_all_env_variables.remote())) + for i in range(num_engines): + # When tensor_parallel_size=1 and RAY_EXPERIMENTAL_NOSET_*_VISIBLE_DEVICES is not set + # (vLLM mp backend will work smoothly only when *_VISIBLE_DEVICES is modified), + # vLLM init model in LLMEngine directly, assign 1 GPU for it. 
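# GPU accounting done just below, for two made-up configurations:
#   tensor_parallel_size = 1 and Ray manages *_VISIBLE_DEVICES
#     -> num_gpus = 1: the LLMRayActor itself owns the single GPU and vLLM runs
#        its single-GPU executor in-process (hence the WorkerWrap patch above).
#   tensor_parallel_size = 4, or one of the RAY_EXPERIMENTAL_NOSET_* flags set
#     -> num_gpus = 0: the actor is only a driver; a placement group with four
#        {"GPU": 1, "CPU": 1} bundles reserves the devices, and vLLM's own Ray
#        workers are scheduled into those bundles
#        (placement_group_capture_child_tasks=True).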
+ num_gpus = int(tensor_parallel_size == 1 and not noset_visible_devices) + scheduling_strategy = None + + if tensor_parallel_size > 1 or noset_visible_devices: + bundles = [{"GPU": 1, "CPU": 1}] * tensor_parallel_size + pg = placement_group(bundles) + ray.get(pg.ready()) + + scheduling_strategy = PlacementGroupSchedulingStrategy( + placement_group=pg, placement_group_capture_child_tasks=True, placement_group_bundle_index=0 + ) + + vllm_engines.append( + LLMRayActor.options( + num_cpus=1, + num_gpus=num_gpus, + scheduling_strategy=scheduling_strategy, + ).remote( + pretrain, + noset_visible_devices=noset_visible_devices, + trust_remote_code=True, + tensor_parallel_size=tensor_parallel_size, + dtype="bfloat16", + seed=seed + i, + enable_prefix_caching=enable_prefix_caching, + enforce_eager=enforce_eager, + max_model_len=max_model_len, + ) + ) + + return vllm_engines + + +if __name__ == "__main__": + llm = LLMRayActor.remote("meta-llama/Llama-2-7b-chat-hf", tensor_parallel_size=4) + output = ray.get(llm.generate.remote("San Franciso is a")) + print(f"output: {output}") diff --git a/PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/openrlhf/trainer/ray/vllm_worker_wrap.py b/PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/openrlhf/trainer/ray/vllm_worker_wrap.py new file mode 100644 index 0000000000..2f324793d0 --- /dev/null +++ b/PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/openrlhf/trainer/ray/vllm_worker_wrap.py @@ -0,0 +1,58 @@ +import torch +from vllm.worker.worker import Worker + +from openrlhf.utils.distributed_util import init_process_group +from openrlhf.utils.logging_utils import init_logger + +logger = init_logger(__name__) + + +class WorkerWrap(Worker): + def init_process_group(self, master_address, master_port, rank_offset, world_size, group_name, backend="nccl", use_ray=False): + """Init torch process group for model weights update""" + assert torch.distributed.is_initialized(), f"default torch process group must be initialized" + assert group_name != "", f"group name must not be empty" + + rank = torch.distributed.get_rank() + rank_offset + if use_ray: + import ray.util.collective as collective + collective.init_collective_group( + world_size=world_size, + rank=rank, + backend=backend, + group_name=group_name + ) + self._model_update_group = group_name + else: + self._model_update_group = init_process_group( + backend=backend, + init_method=f"tcp://{master_address}:{master_port}", + world_size=world_size, + rank=rank, + group_name=group_name, + ) + self._model_update_with_ray = use_ray + print( + f"init_process_group: master_address={master_address}, master_port={master_port}, ", + f"rank={rank}, world_size={world_size}, group_name={group_name}", + ) + + def update_weight(self, name, dtype, shape, empty_cache=False): + """Broadcast weight to all vllm workers from source rank 0 (actor model)""" + if torch.distributed.get_rank() == 0: + print(f"update weight: {name}, dtype: {dtype}, shape: {shape}") + + assert dtype == self.model_config.dtype, f"mismatch dtype: src {dtype}, dst {self.model_config.dtype}" + weight = torch.empty(shape, dtype=dtype, device="cuda") + if self._model_update_with_ray: + import ray.util.collective as collective + collective.broadcast(weight, 0, group_name=self._model_update_group) + else: + torch.distributed.broadcast(weight, 0, group=self._model_update_group) + + self.model_runner.model.load_weights(weights=[(name, weight)]) + + del weight + # TODO: should we empty cache if all weights have updated? 
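# Receive side of the update_weight() call above, reduced to plain
# torch.distributed (the function and argument names here are illustrative; in
# the real path rank 0 is the DeepSpeed training rank and the receivers are the
# vLLM workers wrapped by WorkerWrap):
import torch
import torch.distributed as dist

def receive_updated_weight(name, dtype, shape, group):
    # non-source ranks allocate an empty buffer; broadcast fills it in place
    weight = torch.empty(shape, dtype=dtype, device="cuda")
    dist.broadcast(weight, src=0, group=group)
    return name, weight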
+ # if empty_cache: + # torch.cuda.empty_cache() diff --git a/PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/openrlhf/trainer/rm_trainer.py b/PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/openrlhf/trainer/rm_trainer.py new file mode 100644 index 0000000000..63a39140e8 --- /dev/null +++ b/PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/openrlhf/trainer/rm_trainer.py @@ -0,0 +1,370 @@ +import os +from abc import ABC + +import torch +from torch.optim import Optimizer +from tqdm import tqdm + +from openrlhf.models import LogExpLoss, PairWiseLoss +from openrlhf.utils.distributed_sampler import DistributedSampler + + +class RewardModelTrainer(ABC): + """ + Trainer for training a reward model. + + Args: + model (torch.nn.Module): The model to be trained. + strategy (Strategy): The training strategy to apply. + optim (Optimizer): The optimizer to use during training. + train_dataloader (DataLoader): The dataloader for the training dataset. + eval_dataloader (DataLoader): The dataloader for the evaluation dataset. + scheduler (Scheduler): The learning rate scheduler for dynamic adjustments during training. + tokenizer (Tokenizer): The tokenizer for processing input text data. + max_norm (float, defaults to 0.5): Maximum gradient norm for gradient clipping. + max_epochs (int, defaults to 2): Maximum number of training epochs. + loss (str, defaults to "sigmoid"): The loss function to use during training, e.g., "sigmoid". + """ + + def __init__( + self, + model, + strategy, + optim: Optimizer, + train_dataloader, + eval_dataloader, + scheduler, + tokenizer, + max_norm=0.5, + max_epochs: int = 2, + loss="sigmoid", + ) -> None: + super().__init__() + self.strategy = strategy + self.epochs = max_epochs + self.max_norm = max_norm + self.model = model + self.train_dataloader = train_dataloader + self.eval_dataloader = eval_dataloader + self.scheduler = scheduler + self.optimizer = optim + self.tokenizer = tokenizer + self.args = strategy.args + + if loss == "sigmoid": + self.loss_fn = PairWiseLoss() + self.strategy.print("LogSigmoid Loss") + else: + self.loss_fn = LogExpLoss() + self.strategy.print("LogExp Loss") + + # Mixtral 8*7b + self.aux_loss = self.args.aux_loss_coef > 1e-8 + + # packing samples + self.packing_samples = strategy.args.packing_samples + + self.margin_loss = self.strategy.args.margin_loss + self.compute_fp32_loss = self.strategy.args.compute_fp32_loss + + # wandb/tensorboard setting + self._wandb = None + self._tensorboard = None + if self.strategy.args.use_wandb and self.strategy.is_rank_0(): + import wandb + + self._wandb = wandb + if not wandb.api.api_key: + wandb.login(key=strategy.args.use_wandb) + wandb.init( + entity=strategy.args.wandb_org, + project=strategy.args.wandb_project, + group=strategy.args.wandb_group, + name=strategy.args.wandb_run_name, + config=strategy.args.__dict__, + reinit=True, + ) + + wandb.define_metric("train/global_step") + wandb.define_metric("train/*", step_metric="train/global_step", step_sync=True) + wandb.define_metric("eval/global_step") + wandb.define_metric("eval/*", step_metric="eval/global_step", step_sync=True) + + # Initialize TensorBoard writer if wandb is not available + if self.strategy.args.use_tensorboard and self._wandb is None and self.strategy.is_rank_0(): + from torch.utils.tensorboard import SummaryWriter + + os.makedirs(self.strategy.args.use_tensorboard, exist_ok=True) + log_dir = os.path.join(self.strategy.args.use_tensorboard, strategy.args.wandb_run_name) + self._tensorboard = SummaryWriter(log_dir=log_dir) + + def fit(self, 
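# Both ranking losses selected above are forms of the Bradley-Terry objective on
# reward differences. A self-contained sketch of the math (the exact classes live
# in openrlhf/models, so treat this as an approximation, with the optional margin
# used when margin_loss is enabled):
import torch
import torch.nn.functional as F

def pairwise_logsigmoid_loss(chosen_reward, reject_reward, margin=None):
    diff = chosen_reward - reject_reward
    if margin is not None:
        diff = diff - margin
    # -logsigmoid(x) == log(1 + exp(-x)), i.e. the "LogExp" form of the same loss
    return -F.logsigmoid(diff).mean()

chosen = torch.tensor([1.2, 0.3])
reject = torch.tensor([0.4, 0.9])
loss = pairwise_logsigmoid_loss(chosen, reject)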
args, consumed_samples=0, num_update_steps_per_epoch=None): + # get eval and save steps + if args.eval_steps == -1: + args.eval_steps = num_update_steps_per_epoch # Evaluate once per epoch + if args.save_steps == -1: + args.save_steps = float("inf") # do not save ckpt + + # Restore step and start_epoch + step = consumed_samples // args.train_batch_size * self.strategy.accumulated_gradient + 1 + start_epoch = consumed_samples // args.train_batch_size // num_update_steps_per_epoch + consumed_samples = consumed_samples % (num_update_steps_per_epoch * args.train_batch_size) + + epoch_bar = tqdm(range(start_epoch, self.epochs), desc="Train epoch", disable=not self.strategy.is_rank_0()) + acc_sum = 0 + loss_sum = 0 + for epoch in range(start_epoch, self.epochs): + if isinstance(self.train_dataloader.sampler, DistributedSampler): + self.train_dataloader.sampler.set_epoch( + epoch, consumed_samples=0 if epoch > start_epoch else consumed_samples + ) + + # train + step_bar = tqdm( + range(self.train_dataloader.__len__()), + desc="Train step of epoch %d" % epoch, + disable=not self.strategy.is_rank_0(), + ) + + self.model.train() + for data in self.train_dataloader: + if not self.packing_samples: + chosen_ids, c_mask, reject_ids, r_mask, margin = data + chosen_ids = chosen_ids.squeeze(1).to(torch.cuda.current_device()) + c_mask = c_mask.squeeze(1).to(torch.cuda.current_device()) + reject_ids = reject_ids.squeeze(1).to(torch.cuda.current_device()) + r_mask = r_mask.squeeze(1).to(torch.cuda.current_device()) + + chosen_reward, reject_reward, aux_loss = self.concatenated_forward( + self.model, chosen_ids, c_mask, reject_ids, r_mask + ) + else: + packed_input_ids, packed_attention_masks, packed_seq_lens, margin = data + packed_input_ids, packed_attention_masks = packed_input_ids.to( + torch.cuda.current_device() + ), packed_attention_masks.to(torch.cuda.current_device()) + + chosen_reward, reject_reward, aux_loss = self.packed_samples_forward( + self.model, packed_input_ids, packed_attention_masks, packed_seq_lens + ) + + if self.margin_loss: + margin = torch.tensor(margin).to(torch.cuda.current_device()) + else: + margin = None + + # loss function + if self.compute_fp32_loss: + chosen_reward = chosen_reward.float() + reject_reward = reject_reward.float() + + preference_loss = self.loss_fn(chosen_reward, reject_reward, margin) + # mixtral + if not self.aux_loss: + aux_loss = 0 + + loss = preference_loss + aux_loss * self.args.aux_loss_coef + self.strategy.backward(loss, self.model, self.optimizer) + self.strategy.optimizer_step(self.optimizer, self.model, self.scheduler) + + acc = (chosen_reward > reject_reward).float().mean().item() + acc_sum += acc + loss_sum += preference_loss.item() + # optional rm info + logs_dict = { + "loss": preference_loss.item(), + "acc": acc, + "chosen_reward": chosen_reward.mean().item(), + "reject_reward": reject_reward.mean().item(), + "lr": self.scheduler.get_last_lr()[0], + } + if self.aux_loss: + logs_dict["aux_loss"] = aux_loss.item() + + # step bar + logs_dict = self.strategy.all_reduce(logs_dict) + step_bar.set_postfix(logs_dict) + step_bar.update() + + # logs/checkpoints/evaluation + if step % self.strategy.accumulated_gradient == 0: + logs_dict["loss_mean"] = loss_sum / self.strategy.accumulated_gradient + logs_dict["acc_mean"] = acc_sum / self.strategy.accumulated_gradient + loss_sum = 0 + acc_sum = 0 + global_step = step // self.strategy.accumulated_gradient + client_states = {"consumed_samples": global_step * args.train_batch_size} + 
self.save_logs_and_checkpoints(args, global_step, step_bar, logs_dict, client_states) + + step += 1 + epoch_bar.update() + + if self._wandb is not None and self.strategy.is_rank_0(): + self._wandb.finish() + if self._tensorboard is not None and self.strategy.is_rank_0(): + self._tensorboard.close() + + # logs/checkpoints/evaluate + def save_logs_and_checkpoints(self, args, global_step, step_bar, logs_dict={}, client_states={}): + if global_step % args.logging_steps == 0: + # wandb + if self._wandb is not None and self.strategy.is_rank_0(): + logs = {"train/%s" % k: v for k, v in {**logs_dict, "global_step": global_step}.items()} + self._wandb.log(logs) + # TensorBoard + elif self._tensorboard is not None and self.strategy.is_rank_0(): + for k, v in logs_dict.items(): + self._tensorboard.add_scalar(f"train/{k}", v, global_step) + + # eval + if global_step % args.eval_steps == 0: + # do eval when len(dataloader) > 0, avoid zero division in eval. + if len(self.eval_dataloader) > 0: + self.evaluate(self.eval_dataloader, global_step) + # save ckpt + # TODO: save best model on dev, use loss/perplexity on whole dev dataset as metric + if global_step % args.save_steps == 0: + tag = f"global_step{global_step}" + self.strategy.save_ckpt( + self.model, args.ckpt_path, tag, args.max_ckpt_num, args.max_ckpt_mem, client_states + ) + + def evaluate(self, eval_dataloader, steps=0): + step_bar = tqdm( + range(eval_dataloader.__len__()), + desc="Eval stage of steps %d" % steps, + disable=not self.strategy.is_rank_0(), + ) + self.model.eval() + with torch.no_grad(): + acc = 0 + rewards = [] + loss_sum = 0 + for data in eval_dataloader: + if not self.packing_samples: + chosen_ids, c_mask, reject_ids, r_mask, margin = data + chosen_ids = chosen_ids.squeeze(1).to(torch.cuda.current_device()) + c_mask = c_mask.squeeze(1).to(torch.cuda.current_device()) + reject_ids = reject_ids.squeeze(1).to(torch.cuda.current_device()) + r_mask = r_mask.squeeze(1).to(torch.cuda.current_device()) + + chosen_reward, reject_reward, _ = self.concatenated_forward( + self.model, chosen_ids, c_mask, reject_ids, r_mask + ) + else: + packed_input_ids, packed_attention_masks, packed_seq_lens, margin = data + packed_input_ids, packed_attention_masks = packed_input_ids.to( + torch.cuda.current_device() + ), packed_attention_masks.to(torch.cuda.current_device()) + + chosen_reward, reject_reward, _ = self.packed_samples_forward( + self.model, packed_input_ids, packed_attention_masks, packed_seq_lens + ) + + if self.margin_loss: + margin = torch.tensor(margin).to(torch.cuda.current_device()) + else: + margin = None + + loss = self.loss_fn(chosen_reward, reject_reward, margin) + + rewards += [chosen_reward.flatten(), reject_reward.flatten()] + acc += (chosen_reward > reject_reward).float().mean().item() + loss_sum += loss.item() + step_bar.update() + + acc_mean = acc / self.eval_dataloader.__len__() + loss_mean = loss_sum / self.eval_dataloader.__len__() + + rewards = torch.cat(rewards).float() + rewards = self.strategy.all_gather(rewards) + reward_mean = torch.mean(rewards) + reward_std = torch.std(rewards).clamp(min=1e-8) + + # save mean std + self.strategy.print("Set reward mean std") + unwrap_model = self.strategy._unwrap_model(self.model) + unwrap_model.config.mean = reward_mean.item() + unwrap_model.config.std = reward_std.item() + + bar_dict = { + "eval_loss": loss_mean, + "acc_mean": acc_mean, + "reward_mean": reward_mean.item(), + "reward_std": reward_std.item(), + } + logs = self.strategy.all_reduce(bar_dict) + 
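# Editorial aside (not part of the patch): evaluate() above writes reward_mean / reward_std
# into the unwrapped model's config ("Set reward mean std"). A minimal sketch of how such
# stored statistics can later be used to z-normalize raw reward-model outputs; the values
# below are hypothetical.
import torch

config_mean, config_std = 0.37, 1.42                   # hypothetical stored statistics
raw_rewards = torch.tensor([0.1, 1.9, -0.5])           # hypothetical reward-model outputs
normalized = (raw_rewards - config_mean) / config_std  # plain z-score normalization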
step_bar.set_postfix(logs) + + histgram = torch.histogram(rewards.cpu(), bins=10, range=(-10, 10), density=True) * 2 + self.strategy.print("histgram") + self.strategy.print(histgram) + + if self.strategy.is_rank_0(): + if self._wandb is not None: + logs = {"eval/%s" % k: v for k, v in {**logs, "global_step": steps}.items()} + self._wandb.log(logs) + elif self._tensorboard is not None: + for k, v in logs.items(): + self._tensorboard.add_scalar(f"eval/{k}", v, steps) + self.model.train() # reset model state + + def concatenated_forward(self, model, chosen_ids, c_mask, reject_ids, r_mask): + """Run the given model on the given batch of inputs, concatenating the chosen and rejected inputs together. + + We do this to avoid doing two forward passes, because it's faster for FSDP. + """ + input_ids, att_masks = self.concatenated_inputs(chosen_ids, c_mask, reject_ids, r_mask) + all_values, output = model(input_ids, attention_mask=att_masks, return_output=True) + chosen_rewards = all_values[: chosen_ids.shape[0]] + rejected_rewards = all_values[chosen_ids.shape[0] :] + aux_loss = output.aux_loss if "aux_loss" in output else [] + return chosen_rewards, rejected_rewards, aux_loss + + def concatenated_inputs(self, chosen_ids, c_mask, reject_ids, r_mask): + """Concatenate the chosen and rejected inputs into a single tensor. + + Args: + batch: A batch of data. Must contain the keys 'chosen_input_ids' and 'rejected_input_ids', which are tensors of shape (batch_size, sequence_length). + + Returns: + A dictionary containing the concatenated inputs under the key 'concatenated_input_ids'. + """ + + def pad_to_length(tensor, length, pad_value, dim=-1): + if tensor.size(dim) >= length: + return tensor + else: + pad_size = list(tensor.shape) + pad_size[dim] = length - tensor.size(dim) + # left pad + return torch.cat( + [pad_value * torch.ones(*pad_size, dtype=tensor.dtype, device=tensor.device), tensor], dim=dim + ) + + max_length = max(chosen_ids.shape[1], reject_ids.shape[1]) + inputs_ids = torch.cat( + ( + pad_to_length(chosen_ids, max_length, self.tokenizer.pad_token_id), + pad_to_length(reject_ids, max_length, self.tokenizer.pad_token_id), + ), + dim=0, + ) + max_length = max(c_mask.shape[1], r_mask.shape[1]) + att_masks = torch.cat((pad_to_length(c_mask, max_length, 0), pad_to_length(r_mask, max_length, 0)), dim=0) + return inputs_ids, att_masks + + def packed_samples_forward(self, model, packed_input_ids, packed_attention_masks, packed_seq_lens): + all_values, output = model( + packed_input_ids, + attention_mask=packed_attention_masks, + return_output=True, + ring_attn_group=self.strategy.ring_attn_group, + packed_seq_lens=packed_seq_lens, + ) + half_len = len(packed_seq_lens) // 2 + chosen_rewards = all_values[:half_len] + rejected_rewards = all_values[half_len:] + aux_loss = output.aux_loss if "aux_loss" in output else [] + + return chosen_rewards, rejected_rewards, aux_loss diff --git a/PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/openrlhf/trainer/sft_trainer.py b/PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/openrlhf/trainer/sft_trainer.py new file mode 100644 index 0000000000..fa92452d50 --- /dev/null +++ b/PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/openrlhf/trainer/sft_trainer.py @@ -0,0 +1,318 @@ +import os +from abc import ABC + +import torch +from torch.optim import Optimizer +from tqdm import tqdm + +from openrlhf.models import GPTLMLoss +from openrlhf.utils.distributed_sampler import DistributedSampler + + +class SFTTrainer(ABC): + """ + Trainer for supervised fine-tuning (SFT). 
+ + Args: + model (torch.nn.Module): The model to be trained. + strategy (Strategy): The training strategy to be applied. + optim (Optimizer): The optimizer for model training. + train_dataloader (DataLoader): The dataloader for the training dataset. + eval_dataloader (DataLoader): The dataloader for the evaluation dataset. + scheduler (Scheduler): The learning rate scheduler to adjust training rates. + max_norm (float, defaults to 1): Maximum gradient norm for clipping to prevent exploding gradients. + pretrain_mode (bool, defaults to False): Flag to indicate if the trainer is in pre-training mode. + batch_size (int, defaults to 1): Batch size for training. + max_epochs (int, defaults to 2): The maximum number of training epochs. + tokenizer (Tokenizer, optional): The tokenizer for processing input data. + save_hf_ckpt (bool): Whether to save huggingface-format model weight. + disable_ds_ckpt (bool): Whether not to save deepspeed-format model weight. (Deepspeed model weight is used for training recovery) + """ + + def __init__( + self, + model, + strategy, + optim: Optimizer, + train_dataloader, + eval_dataloader, + scheduler, + max_norm: float = 1, + pretrain_mode: bool = False, + batch_size: int = 1, + max_epochs: int = 2, + tokenizer=None, + save_hf_ckpt: bool = False, + disable_ds_ckpt: bool = False, + ) -> None: + super().__init__() + self.strategy = strategy + self.epochs = max_epochs + self.batch_size = batch_size + self.max_norm = max_norm + self.train_dataloader = train_dataloader + self.eval_dataloader = eval_dataloader + self.scheduler = scheduler + self.pretrain_mode = pretrain_mode + self.model = model + self.tokenizer = tokenizer + self.optimizer = optim + self.args = strategy.args + self.save_hf_ckpt = save_hf_ckpt + self.disable_ds_ckpt = disable_ds_ckpt + + self.loss_fn = GPTLMLoss(ring_attn_group=self.strategy.ring_attn_group) + + # Mixtral 8*7b + self.aux_loss = self.args.aux_loss_coef > 1e-8 + + # packing samples + self.packing_samples = strategy.args.packing_samples + + # wandb/tensorboard setting + self._wandb = None + self._tensorboard = None + if self.strategy.args.use_wandb and self.strategy.is_rank_0(): + import wandb + + self._wandb = wandb + if not wandb.api.api_key: + wandb.login(key=strategy.args.use_wandb) + wandb.init( + entity=strategy.args.wandb_org, + project=strategy.args.wandb_project, + group=strategy.args.wandb_group, + name=strategy.args.wandb_run_name, + config=strategy.args.__dict__, + reinit=True, + ) + + wandb.define_metric("train/global_step") + wandb.define_metric("train/*", step_metric="train/global_step", step_sync=True) + wandb.define_metric("eval/global_step") + wandb.define_metric("eval/*", step_metric="eval/global_step", step_sync=True) + + # Initialize TensorBoard writer if wandb is not available + if self.strategy.args.use_tensorboard and self._wandb is None and self.strategy.is_rank_0(): + from torch.utils.tensorboard import SummaryWriter + + os.makedirs(self.strategy.args.use_tensorboard, exist_ok=True) + log_dir = os.path.join(self.strategy.args.use_tensorboard, strategy.args.wandb_run_name) + self._tensorboard = SummaryWriter(log_dir=log_dir) + + def fit(self, args, consumed_samples=0, num_update_steps_per_epoch=None): + # get eval and save steps + if args.eval_steps == -1: + args.eval_steps = num_update_steps_per_epoch # Evaluate once per epoch + if args.save_steps == -1: + args.save_steps = float("inf") # do not save ckpt + + # Restore step and start_epoch + step = consumed_samples // args.train_batch_size * 
self.strategy.accumulated_gradient + 1 + start_epoch = consumed_samples // args.train_batch_size // num_update_steps_per_epoch + consumed_samples = consumed_samples % (num_update_steps_per_epoch * args.train_batch_size) + + epoch_bar = tqdm( + range(start_epoch, self.epochs), + desc="Train epoch", + disable=not self.strategy.is_rank_0(), + ) + loss_sum = 0 + for epoch in range(start_epoch, self.epochs): + if isinstance(self.train_dataloader.sampler, DistributedSampler): + self.train_dataloader.sampler.set_epoch( + epoch, consumed_samples=0 if epoch > start_epoch else consumed_samples + ) + + step_bar = tqdm( + range(self.train_dataloader.__len__()), + desc="Train step of epoch %d" % epoch, + disable=not self.strategy.is_rank_0(), + ) + + # train + self.model.train() + for prompt_id_lens, inputs, attention_masks, infos in self.train_dataloader: + if self.packing_samples: + inputs = inputs.to(torch.cuda.current_device()) + attention_mask = attention_masks.to(torch.cuda.current_device()) + else: + inputs = inputs.to(torch.cuda.current_device()).squeeze(1) + attention_mask = attention_masks.to(torch.cuda.current_device()).squeeze(1) + + if self.strategy.ring_attn_group is None: + output = self.model(inputs, attention_mask=attention_mask, return_output=True) + else: + output = self.model( + inputs, + attention_mask=attention_mask, + return_output=True, + ring_attn_group=self.strategy.ring_attn_group, + packed_seq_lens=infos["input_length"], + ) + + # loss function + labels = torch.where( + attention_mask.bool(), + inputs, + self.loss_fn.IGNORE_INDEX, + ) + # mixtral + if self.aux_loss: + aux_loss = output.aux_loss + else: + aux_loss = 0 + + if not self.pretrain_mode: + if self.packing_samples: + # As response_ranges need to constrain the dataset organization strictly, we handle multiturn feature separately. 
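# Editorial aside (not part of the patch): a standalone sketch of the response-range masking
# applied just below -- for packed multiturn samples only tokens inside each (start, end)
# response span keep their labels, everything else is ignored. The ranges and the ignore
# value here are hypothetical (the loss commonly uses -100 as its ignore index).
import torch

IGNORE_INDEX = -100
labels = torch.arange(12).unsqueeze(0)   # one packed sequence of 12 token labels
response_ranges = [(3, 6), (9, 12)]      # hypothetical spans of assistant responses

masked = torch.full_like(labels, IGNORE_INDEX)
for start, end in response_ranges:
    masked[0, start:end] = labels[0, start:end]   # keep loss only on response tokens
# masked -> [-100, -100, -100, 3, 4, 5, -100, -100, -100, 9, 10, 11]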
+ if infos["response_ranges"]: + dump_labels = torch.full(labels.size(), self.loss_fn.IGNORE_INDEX).to(labels.device) + for response_ranges in infos["response_ranges"]: + for response_range in response_ranges: + dump_labels[0][response_range[0]: response_range[1]] = labels[0][response_range[0]: response_range[1]] + labels = dump_labels + else: + index = 0 + for input_length, source_len in zip(infos["input_length"], prompt_id_lens): + labels[0][index : index + source_len] = self.loss_fn.IGNORE_INDEX + index += input_length + else: + for label, source_len in zip(labels, prompt_id_lens): + label[:source_len] = self.loss_fn.IGNORE_INDEX + + gpt_loss = self.loss_fn(output.logits, labels) + loss = gpt_loss + aux_loss * self.args.aux_loss_coef + self.strategy.backward(loss, self.model, self.optimizer) + self.strategy.optimizer_step(self.optimizer, self.model, self.scheduler) + + loss_sum += gpt_loss.item() + logs_dict = { + "gpt_loss": gpt_loss.item(), + "lr": self.scheduler.get_last_lr()[0], + } + if self.aux_loss: + logs_dict["aux_loss"] = aux_loss.item() + # step bar + logs_dict = self.strategy.all_reduce(logs_dict) + step_bar.set_postfix(logs_dict) + step_bar.update() + + # logs/checkpoints/evaluation + if step % self.strategy.accumulated_gradient == 0: + logs_dict["loss_mean"] = loss_sum / self.strategy.accumulated_gradient + loss_sum = 0 + global_step = step // self.strategy.accumulated_gradient + client_states = {"consumed_samples": global_step * args.train_batch_size} + self.save_logs_and_checkpoints(args, global_step, step_bar, logs_dict, client_states) + + step += 1 + + epoch_bar.update() + + if self._wandb is not None and self.strategy.is_rank_0(): + self._wandb.finish() + if self._tensorboard is not None and self.strategy.is_rank_0(): + self._tensorboard.close() + + # logs/checkpoints/evaluation + def save_logs_and_checkpoints(self, args, global_step, step_bar, logs_dict={}, client_states={}): + if global_step % args.logging_steps == 0: + # wandb + if self._wandb is not None and self.strategy.is_rank_0(): + logs = {"train/%s" % k: v for k, v in {**logs_dict, "global_step": global_step}.items()} + self._wandb.log(logs) + # TensorBoard + elif self._tensorboard is not None and self.strategy.is_rank_0(): + for k, v in logs_dict.items(): + self._tensorboard.add_scalar(f"train/{k}", v, global_step) + + # eval + if global_step % args.eval_steps == 0: + # do eval when len(dataloader) > 0, avoid zero division in eval. 
+ if len(self.eval_dataloader) > 0: + self.evaluate(self.eval_dataloader, global_step) + + # save ckpt + # TODO: save best model on dev, use loss/perplexity on whole dev dataset as metric + if global_step % args.save_steps == 0: + tag = f"global_step{global_step}" + if not self.disable_ds_ckpt: + self.strategy.save_ckpt( + self.model.model, args.ckpt_path, tag, args.max_ckpt_num, args.max_ckpt_mem, client_states + ) + if self.save_hf_ckpt: + save_path = os.path.join(args.ckpt_path, f"{tag}_hf") + self.strategy.save_model(self.model, self.tokenizer, save_path) + + def evaluate(self, eval_dataloader, steps=0): + times = 0 + self.model.eval() + with torch.no_grad(): + loss_sum = 0 + step_bar = tqdm( + range(eval_dataloader.__len__()), + desc="Eval stage of steps %d" % steps, + disable=not self.strategy.is_rank_0(), + ) + + for prompt_id_lens, inputs, attention_masks, infos in eval_dataloader: + if self.packing_samples: + inputs = inputs.to(torch.cuda.current_device()) + attention_mask = attention_masks.to(torch.cuda.current_device()) + else: + inputs = inputs.to(torch.cuda.current_device()).squeeze(1) + attention_mask = attention_masks.to(torch.cuda.current_device()).squeeze(1) + + if self.strategy.ring_attn_group is None: + output = self.model(inputs, attention_mask=attention_mask, return_output=True) + else: + output = self.model( + inputs, + attention_mask=attention_mask, + return_output=True, + ring_attn_group=self.strategy.ring_attn_group, + packed_seq_lens=infos["input_length"], + ) + + # loss function + labels = torch.where( + attention_mask.bool(), + inputs, + self.loss_fn.IGNORE_INDEX, + ) + + if not self.pretrain_mode: + if self.packing_samples: + if infos["response_ranges"]: + dump_labels = torch.full(labels.size(), self.loss_fn.IGNORE_INDEX).to(labels.device) + for response_ranges in infos["response_ranges"]: + for response_range in response_ranges: + dump_labels[0][response_range[0]: response_range[1]] = labels[0][response_range[0]: response_range[1]] + labels = dump_labels + else: + index = 0 + for input_length, source_len in zip(infos["input_length"], prompt_id_lens): + labels[0][index : index + source_len] = self.loss_fn.IGNORE_INDEX + index += input_length + else: + for label, source_len in zip(labels, prompt_id_lens): + label[:source_len] = self.loss_fn.IGNORE_INDEX + + loss = self.loss_fn(output.logits, labels) + + times += 1 + loss_sum += loss.item() + bar_dict = {"eval gpt_loss": loss_sum / times} + step_bar.update() + logs = self.strategy.all_reduce(bar_dict) + step_bar.set_postfix(logs) + + if self.strategy.is_rank_0(): + if self._wandb is not None: + logs = {"eval/%s" % k: v for k, v in {**logs, "global_step": steps}.items()} + self._wandb.log(logs) + elif self._tensorboard is not None: + for k, v in logs.items(): + self._tensorboard.add_scalar(f"eval/{k}", v, steps) + self.model.train() # reset model state diff --git a/PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/openrlhf/utils/__init__.py b/PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/openrlhf/utils/__init__.py new file mode 100644 index 0000000000..08ab0a9ba9 --- /dev/null +++ b/PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/openrlhf/utils/__init__.py @@ -0,0 +1,10 @@ +from .processor import get_processor, reward_normalization +from .utils import blending_datasets, get_strategy, get_tokenizer + +__all__ = [ + "get_processor", + "reward_normalization", + "blending_datasets", + "get_strategy", + "get_tokenizer", +] diff --git a/PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/openrlhf/utils/deepspeed/__init__.py 
b/PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/openrlhf/utils/deepspeed/__init__.py new file mode 100644 index 0000000000..473108291c --- /dev/null +++ b/PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/openrlhf/utils/deepspeed/__init__.py @@ -0,0 +1,5 @@ +from .deepspeed import DeepspeedStrategy + +__all__ = [ + "DeepspeedStrategy", +] diff --git a/PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/openrlhf/utils/deepspeed/deepspeed.py b/PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/openrlhf/utils/deepspeed/deepspeed.py new file mode 100644 index 0000000000..ea01203e01 --- /dev/null +++ b/PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/openrlhf/utils/deepspeed/deepspeed.py @@ -0,0 +1,467 @@ +import os +import random +import shutil +from abc import ABC +from collections import defaultdict +from datetime import timedelta +from typing import List, Tuple, Union + +import deepspeed +import numpy as np +import torch +import torch.nn as nn +import torch.optim as optim +from deepspeed.ops.adam import DeepSpeedCPUAdam, FusedAdam +from peft import PeftModel, get_peft_model_state_dict +from torch import distributed as dist +from torch.optim import Optimizer +from torch.utils.data import DataLoader + +from openrlhf.models import Actor +from openrlhf.models.ring_attn_utils import get_ring_attn_group, set_ring_attn_group +from openrlhf.utils.distributed_sampler import DistributedSampler + +from .deepspeed_utils import ( + _z3_params_to_fetch, + get_eval_ds_config, + get_optimizer_grouped_parameters, + get_train_ds_config, +) + +ModelOptimPair = Tuple[nn.Module, Optimizer] +ModelOrModelOptimPair = Union[nn.Module, ModelOptimPair] + + +class DeepspeedStrategy(ABC): + """ + The strategy for training with Accelerator. + """ + + def __init__( + self, + seed: int = 42, + max_norm: float = 0.0, + micro_train_batch_size=1, + train_batch_size=1, + zero_stage=2, + bf16=True, + args=None, + ) -> None: + super().__init__() + + self.args = args + self.stage = zero_stage + self.train_batch_size = train_batch_size + self.micro_train_batch_size = micro_train_batch_size + self.bf16 = bf16 + self.seed = seed + self.max_norm = max_norm + self.adam_offload = getattr(args, "adam_offload", False) + self.zpg = getattr(args, "zpg", 1) + self.grad_accum_dtype = getattr(args, "grad_accum_dtype", None) + # overlap_comm + self.overlap_comm = getattr(args, "overlap_comm", False) + + self.is_rlhf = False + self.time_steps = defaultdict(int) + + def set_seed(self, seed: int) -> None: + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + + def setup_distributed(self, timeout=timedelta(minutes=60)) -> None: + self.set_seed(self.seed) + + if self.args.local_rank == -1 and "LOCAL_RANK" in os.environ: # for slurm + self.args.local_rank = int(os.environ["LOCAL_RANK"]) + + if self.args.local_rank != -1: + torch.cuda.set_device(self.args.local_rank) + # Initializes the distributed backend which will take care of sychronizing nodes/GPUs + deepspeed.init_distributed(timeout=timeout) + self.setup_ring_attn() + self.world_size = dist.get_world_size() + self.accumulated_gradient = ( + self.train_batch_size * self.ring_attn_size // self.micro_train_batch_size // self.world_size + ) + + def setup_ring_attn(self): + self.ring_attn_size = getattr(self.args, "ring_attn_size", 1) + if self.ring_attn_size == 1: + self.ring_attn_rank = 0 + return + + ring_head_stride = getattr(self.args, "ring_head_stride", 1) + for i in range(dist.get_world_size() // self.ring_attn_size): + ring_attn_ranks = list( + 
range( + i * self.ring_attn_size, + (i + 1) * self.ring_attn_size, + ) + ) + group = dist.new_group(ranks=ring_attn_ranks, backend="nccl") + if dist.get_rank() in ring_attn_ranks: + set_ring_attn_group(group) + self.ring_attn_rank = dist.get_rank(group=group) + + from ring_flash_attn import substitute_hf_flash_attn + + substitute_hf_flash_attn(self.ring_attn_group, ring_head_stride) + + @property + def ring_attn_group(self): + return get_ring_attn_group() + + def create_optimizer(self, model, **kwargs) -> Optimizer: + if isinstance(model, Actor): + model = model.model + # Optimizer + AdamOptimizer = DeepSpeedCPUAdam if self.adam_offload else FusedAdam + optim_params = get_optimizer_grouped_parameters(model, kwargs["weight_decay"]) + optim = AdamOptimizer(optim_params, **kwargs) + return optim + + def backward(self, loss: torch.Tensor, model: nn.Module, optimizer: optim.Optimizer, **kwargs) -> None: + if isinstance(model, Actor): + model = model.model + model.backward(loss) + + def optimizer_step( + self, + optimizer: optim.Optimizer, + model: nn.Module, + scheduler, + name="model", + **kwargs, + ) -> None: + if isinstance(model, Actor): + model = model.model + model.step() + + def setup_dataloader( + self, + replay_buffer, + batch_size: int, + pin_memory: bool = False, + shuffle=True, + collate_fn=None, + drop_last=True, + sampler=None, + consumed_samples=0, + ): + # DDP only mode, replay buffers on each rank are different. + if sampler is None: + num_replicas = dist.get_world_size() // self.ring_attn_size + rank = dist.get_rank() // self.ring_attn_size + sampler = DistributedSampler( + replay_buffer, + num_replicas=num_replicas, + rank=rank, + shuffle=shuffle, + seed=self.seed, + drop_last=drop_last, + consumed_samples=consumed_samples, + ) + + return DataLoader( + replay_buffer, + batch_size=batch_size, + sampler=sampler, + drop_last=drop_last, + collate_fn=collate_fn, + pin_memory=pin_memory, + ) + + def _unwrap_model(self, model) -> nn.Module: + if isinstance(model, Actor): + return self._unwrap_model(model.model) + elif hasattr(model, "module"): + return model.module + else: + return model + + def prepare( + self, *models_or_model_optim_pairs: ModelOrModelOptimPair, is_rlhf=False + ) -> Union[List[ModelOrModelOptimPair], ModelOrModelOptimPair]: + ret = [] + self.is_rlhf = is_rlhf + for arg in models_or_model_optim_pairs: + if isinstance(arg, tuple): + assert len(arg) == 3, f'Expect (model, optimizer, scheduler) pair, got a tuple with size "{len(arg)}"' + if arg[0] is not None: + ret.append(self._ds_init_train_model(*arg)) + else: + ret.append((None, None, None)) + else: + ret.append(self._ds_init_eval_model(arg)) + + return ret[0] if len(ret) == 1 else ret + + def _ds_init_train_model(self, model, optim, scheduler): + is_actor = isinstance(model, Actor) + ds_config = self.get_ds_train_config(is_actor) + + engine, optim, _, scheduler = deepspeed.initialize( + model=model.model if is_actor else model, + optimizer=optim, + lr_scheduler=scheduler, + config=ds_config, + args={"local_rank": self.args.local_rank}, + dist_init_required=True, + ) + if is_actor: + model.model = engine + else: + model = engine + + return model, optim, scheduler + + def get_ds_train_config(self, is_actor): + # DS Config + ds_config = get_train_ds_config( + offload=False, + adam_offload=self.adam_offload, + stage=self.stage, + bf16=self.bf16, + max_norm=self.max_norm, + zpg=self.zpg, + grad_accum_dtype=self.grad_accum_dtype, + overlap_comm=self.overlap_comm, + ) + + ds_config["train_micro_batch_size_per_gpu"] = 
self.micro_train_batch_size + train_batch_size = self.train_batch_size + # corner case for ptx loss (backward twice) + if self.is_rlhf and is_actor and self.args.pretrain_data is not None: + train_batch_size *= 2 + ds_config["train_batch_size"] = train_batch_size * self.ring_attn_size + + return ds_config + + def _ds_init_eval_model(self, model): + if not model: + return model + is_actor = isinstance(model, Actor) + ds_config = self.get_ds_eval_config(offload=getattr(model, "_offload", False)) + + engine, *_ = deepspeed.initialize( + model=model.model if is_actor else model, + args={"local_rank": self.args.local_rank}, + config=ds_config, + dist_init_required=True, + ) + if is_actor: + model.model = engine + else: + model = engine + return model + + def get_ds_eval_config(self, offload=False): + # DS Config + ds_config = get_eval_ds_config(offload=offload, stage=self.stage if self.stage == 3 else 0, bf16=self.bf16) + ds_config["train_micro_batch_size_per_gpu"] = self.micro_train_batch_size + ds_config["train_batch_size"] = self.train_batch_size * self.ring_attn_size + + return ds_config + + def moving_average(self, model, model_ema, beta=0.992, device="cpu"): + self.time_steps["ema"] += 1 + if self.time_steps["ema"] % self.accumulated_gradient == 0: + with torch.no_grad(): + for param, param_ema in zip(model.parameters(), model_ema.parameters()): + if param.requires_grad: + if self.stage != 3: + data = param.data.to(device) + param_ema.data.copy_((1 - beta) * data + beta * param_ema.data) + else: + # TODO: use prefiltering for efficiency + params_to_fetch = _z3_params_to_fetch([param, param_ema]) + with deepspeed.zero.GatheredParameters(params_to_fetch, enabled=len(params_to_fetch) > 0): + data = param.data.to(device) + param_ema.data.copy_((1 - beta) * data + beta * param_ema.data) + + def load_model( + self, + model: nn.Module, + path: str, + map_location="cpu", + strict: bool = False, + key_replace_fn=None, + ) -> None: + unwrapped_model = self._unwrap_model(model) + state_dict = torch.load(path, map_location=map_location) + if key_replace_fn: + state_dict = key_replace_fn(state_dict) + unwrapped_model.load_state_dict(state_dict, strict=strict) + + def save_model(self, model: nn.Module, tokenizer, output_dir, **kwargs) -> None: + if self.is_rank_0(): + os.makedirs(output_dir, exist_ok=True) + + # save model weights for ZeRO2/3 + model_to_save = self._unwrap_model(model) + + # gather parameters + output_state_dict = {} + for k, v in model_to_save.named_parameters(): + # only gather z3 params + params_to_fetch = _z3_params_to_fetch([v]) + with deepspeed.zero.GatheredParameters(params_to_fetch, enabled=len(params_to_fetch) > 0): + vv = v.data.cpu() + if self.is_rank_0(): + output_state_dict[k] = vv + + if self.is_rank_0(): + state_dict = model_to_save.state_dict() + + # copy named_buffers with `persistent=True` + for k, v in model_to_save.named_buffers(): + if k not in state_dict: + continue + vv = v.data.cpu() + output_state_dict[k] = vv + + state_dict_keys = set(state_dict.keys()) + output_state_dict_keys = set(output_state_dict.keys()) + + # corner case for tie_word_embeddings, such as Qwen2-0.5B + if getattr(model_to_save.config, "tie_word_embeddings", False) and "lm_head.weight" in state_dict_keys: + state_dict_keys.remove("lm_head.weight") + + assert state_dict_keys.issubset( + output_state_dict_keys + ), f"mismatch keys {output_state_dict_keys.symmetric_difference(state_dict_keys)}" + + # only save peft weights https://github.com/microsoft/DeepSpeed/issues/4295 + if 
isinstance(model_to_save, PeftModel): + model_to_save.save_pretrained(output_dir, **kwargs) + if self.stage == 3: + torch.save( + get_peft_model_state_dict(model_to_save, output_state_dict), + os.path.join(output_dir, "adapter_model.bin"), + ) + filename = os.path.join(output_dir, "adapter_model.safetensors") + if os.path.exists(filename): + os.remove(filename) + else: + # save model + model_to_save.save_pretrained(output_dir, state_dict=output_state_dict, **kwargs) + + # save config + output_config_file = os.path.join(output_dir, "config.json") + model_to_save.config.to_json_file(output_config_file) + # save tokenizer + tokenizer.save_pretrained(output_dir) + + # for models not in AutoModel, copy python module files + train_from_model_path = model_to_save.config._name_or_path + if os.path.exists(train_from_model_path): + for filename in os.listdir(train_from_model_path): + if filename.endswith(".py"): + shutil.copy(os.path.join(train_from_model_path, filename), os.path.join(output_dir, filename)) + + def all_reduce(self, data, op="mean"): + assert op in ("mean", "max", "sum") + if isinstance(data, dict): + ret = {} + for k, v in data.items(): + ret[k] = self.all_reduce(v, op) + return ret + else: + is_tensor = True + if not isinstance(data, torch.Tensor): + data = torch.Tensor([data]) + is_tensor = False + is_cpu_tensor = data.device.type == "cpu" + + if is_cpu_tensor: + data = data.to(torch.cuda.current_device()) + if op == "mean": + data /= self.world_size + dist.all_reduce(data, op=dist.ReduceOp.MAX if op == "max" else dist.ReduceOp.SUM) + if is_cpu_tensor: + data = data.cpu() + return data.item() if not is_tensor else data + + def all_gather(self, data): + if isinstance(data, dict): + ret = {} + for k, v in data.items(): + ret[k] = self.all_gather(v) + return ret + else: + if not isinstance(data, torch.Tensor): + data = torch.Tensor([data]) + is_cpu_tensor = data.device.type == "cpu" + + ret = [torch.zeros_like(data).to(torch.cuda.current_device()) for _ in range(self.world_size)] + dist.all_gather(ret, data.to(torch.cuda.current_device())) + return torch.cat(ret).cpu() if is_cpu_tensor else torch.cat(ret) + + def print(self, *msg): + if self.is_rank_0(): + print(*msg) + + def is_rank_0(self) -> bool: + return dist.get_rank() == 0 + + def get_rank(self) -> int: + return dist.get_rank() + + def save_ckpt(self, model, save_dir, tag=None, max_num=3, max_mem=1000, client_state={}, save_latest=True): + assert isinstance(model, deepspeed.DeepSpeedEngine) + if self.is_rank_0(): + os.makedirs(save_dir, exist_ok=True) + MAX_SIZE = max_mem * 1024**3 # Convert GB to bytes + + while True: + subdirs = sorted( + [ + (os.path.join(save_dir, d), os.path.getmtime(os.path.join(save_dir, d))) + for d in os.listdir(save_dir) + if os.path.isdir(os.path.join(save_dir, d)) + ], + key=lambda x: x[1], + ) + total_size = sum( + os.path.getsize(os.path.join(dirpath, f)) + for subdir, _ in subdirs + for dirpath, _, filenames in os.walk(subdir) + for f in filenames + ) + + if len(subdirs) >= max_num or total_size > MAX_SIZE: + oldest_dir = subdirs[0][0] + if os.path.exists(oldest_dir): + shutil.rmtree(oldest_dir) + self.print(f"Deleted oldest ckpt {oldest_dir}") + else: + break + + dist.barrier() + model.save_checkpoint(save_dir, tag=tag, client_state=client_state, save_latest=save_latest) + + def load_ckpt( + self, + model, + load_dir, + tag=None, + load_module_strict=True, + load_optimizer_states=True, + load_lr_scheduler_states=True, + load_module_only=False, + ): + assert isinstance(model, 
deepspeed.DeepSpeedEngine) + load_path, states = model.load_checkpoint( + load_dir, + tag, + load_module_strict=load_module_strict, + load_optimizer_states=load_optimizer_states, + load_lr_scheduler_states=load_lr_scheduler_states, + load_module_only=load_module_only, + ) + if load_path is None: + raise Exception(f"[deepspeed] failed to resume from checkpoint {load_dir}") + return load_path, states diff --git a/PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/openrlhf/utils/deepspeed/deepspeed_utils.py b/PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/openrlhf/utils/deepspeed/deepspeed_utils.py new file mode 100644 index 0000000000..bb361fc313 --- /dev/null +++ b/PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/openrlhf/utils/deepspeed/deepspeed_utils.py @@ -0,0 +1,107 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus + + +def get_train_ds_config( + offload, + adam_offload=True, + stage=2, + bf16=True, + max_norm=1.0, + zpg=8, + grad_accum_dtype=None, + overlap_comm=False, +): + device = "cpu" if offload else "none" + zero_opt_dict = { + "stage": stage, + "offload_param": {"device": device}, + "offload_optimizer": { + "device": "cpu" if adam_offload else "none", + "pin_memory": True, + }, + "sub_group_size": "auto", + "stage3_max_live_parameters": "auto", + "stage3_max_reuse_distance": "auto", + "stage3_param_persistence_threshold": "auto", + "stage3_prefetch_bucket_size": "auto", + "reduce_bucket_size": "auto", + # ZeRO++ + "zero_hpz_partition_size": zpg, + "zero_quantized_weights": False, + "zero_quantized_gradients": False, + } + if overlap_comm: + zero_opt_dict["overlap_comm"] = True + zero_opt_dict["contiguous_gradients"] = True + + return { + "steps_per_print": 100, + "zero_optimization": zero_opt_dict, + "bf16": { + "enabled": bf16, + }, + "gradient_clipping": max_norm, + "prescale_gradients": False, + "wall_clock_breakdown": False, + "data_types": {"grad_accum_dtype": grad_accum_dtype}, + } + + +def get_eval_ds_config( + offload, + stage=0, + bf16=True, +): + zero_opt_dict = { + "stage": stage, + "stage3_param_persistence_threshold": "auto", + "offload_param": { + "device": "cpu" if offload else "none", + "pin_memory": True, + }, + } + return { + "steps_per_print": 100, + "zero_optimization": zero_opt_dict, + "bf16": { + "enabled": bf16, + }, + "gradient_clipping": 1.0, + "prescale_gradients": False, + "wall_clock_breakdown": False, + } + + +def get_optimizer_grouped_parameters( + model, + weight_decay, + no_decay_name_list=["bias", "layer_norm.weight", "layernorm.weight", "norm.weight", "ln_f.weight"], +): + optimizer_grouped_parameters = [ + { + "params": [ + p + for n, p in model.named_parameters() + if (not any(nd in n for nd in no_decay_name_list) and p.requires_grad) + ], + "weight_decay": weight_decay, + }, + { + "params": [ + p + for n, p in model.named_parameters() + if (any(nd in n for nd in no_decay_name_list) and p.requires_grad) + ], + "weight_decay": 0.0, + }, + ] + return optimizer_grouped_parameters + + +def _z3_params_to_fetch(param_list): + return [p for p in param_list if hasattr(p, "ds_id") and p.ds_status == ZeroParamStatus.NOT_AVAILABLE] diff --git a/PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/openrlhf/utils/distributed_sampler.py b/PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/openrlhf/utils/distributed_sampler.py new file mode 100644 index 0000000000..1f76582028 --- /dev/null +++ 
b/PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/openrlhf/utils/distributed_sampler.py @@ -0,0 +1,151 @@ +import math +from typing import Iterator, Optional, TypeVar + +import torch +import torch.distributed as dist +from torch.utils.data.dataset import Dataset +from torch.utils.data.sampler import Sampler + + +__all__ = ["DistributedSampler"] + + +_T_co = TypeVar("_T_co", covariant=True) + + +# Adapted from https://github.com/pytorch/pytorch/blob/5298acb5c76855bc5a99ae10016efc86b27949bd/torch/utils/data/distributed.py +class DistributedSampler(Sampler[_T_co]): + r"""Sampler that restricts data loading to a subset of the dataset. + + It is especially useful in conjunction with + :class:`torch.nn.parallel.DistributedDataParallel`. In such a case, each + process can pass a :class:`~torch.utils.data.DistributedSampler` instance as a + :class:`~torch.utils.data.DataLoader` sampler, and load a subset of the + original dataset that is exclusive to it. + + .. note:: + Dataset is assumed to be of constant size and that any instance of it always + returns the same elements in the same order. + + Args: + dataset: Dataset used for sampling. + num_replicas (int, optional): Number of processes participating in + distributed training. By default, :attr:`world_size` is retrieved from the + current distributed group. + rank (int, optional): Rank of the current process within :attr:`num_replicas`. + By default, :attr:`rank` is retrieved from the current distributed + group. + shuffle (bool, optional): If ``True`` (default), sampler will shuffle the + indices. + seed (int, optional): random seed used to shuffle the sampler if + :attr:`shuffle=True`. This number should be identical across all + processes in the distributed group. Default: ``0``. + drop_last (bool, optional): if ``True``, then the sampler will drop the + tail of the data to make it evenly divisible across the number of + replicas. If ``False``, the sampler will add extra indices to make + the data evenly divisible across the replicas. Default: ``False``. + + .. warning:: + In distributed mode, calling the :meth:`set_epoch` method at + the beginning of each epoch **before** creating the :class:`DataLoader` iterator + is necessary to make shuffling work properly across multiple epochs. Otherwise, + the same ordering will be always used. + + Example:: + + >>> # xdoctest: +SKIP + >>> sampler = DistributedSampler(dataset) if is_distributed else None + >>> loader = DataLoader(dataset, shuffle=(sampler is None), + ... sampler=sampler) + >>> for epoch in range(start_epoch, n_epochs): + ... if is_distributed: + ... sampler.set_epoch(epoch) + ... 
train(loader) + """ + + def __init__( + self, + dataset: Dataset, + num_replicas: Optional[int] = None, + rank: Optional[int] = None, + shuffle: bool = True, + seed: int = 0, + drop_last: bool = False, + consumed_samples=0, + ) -> None: + if num_replicas is None: + if not dist.is_available(): + raise RuntimeError("Requires distributed package to be available") + num_replicas = dist.get_world_size() + if rank is None: + if not dist.is_available(): + raise RuntimeError("Requires distributed package to be available") + rank = dist.get_rank() + if rank >= num_replicas or rank < 0: + raise ValueError(f"Invalid rank {rank}, rank should be in the interval [0, {num_replicas - 1}]") + self.dataset = dataset + self.num_replicas = num_replicas + self.rank = rank + self.epoch = 0 + self.drop_last = drop_last + # If the dataset length is evenly divisible by # of replicas, then there + # is no need to drop any data, since the dataset will be split equally. + if self.drop_last and len(self.dataset) % self.num_replicas != 0: # type: ignore[arg-type] + # Split to nearest available length that is evenly divisible. + # This is to ensure each rank receives the same amount of data when + # using this Sampler. + self.num_samples = math.ceil( + (len(self.dataset) - self.num_replicas) / self.num_replicas # type: ignore[arg-type] + ) + else: + self.num_samples = math.ceil(len(self.dataset) / self.num_replicas) # type: ignore[arg-type] + self.total_size = self.num_samples * self.num_replicas + self.shuffle = shuffle + self.seed = seed + self.consumed_indicies = consumed_samples // self.num_replicas + + def __iter__(self) -> Iterator[_T_co]: + if self.shuffle: + # deterministically shuffle based on epoch and seed + g = torch.Generator() + g.manual_seed(self.seed + self.epoch) + indices = torch.randperm(len(self.dataset), generator=g).tolist() # type: ignore[arg-type] + else: + indices = list(range(len(self.dataset))) # type: ignore[arg-type] + + if not self.drop_last: + # add extra samples to make it evenly divisible + padding_size = self.total_size - len(indices) + if padding_size <= len(indices): + indices += indices[:padding_size] + else: + indices += (indices * math.ceil(padding_size / len(indices)))[:padding_size] + else: + # remove tail of data to make it evenly divisible. + indices = indices[: self.total_size] + assert len(indices) == self.total_size + + # subsample + indices = indices[self.rank : self.total_size : self.num_replicas] + # skip consumed_samples + indices = indices[self.consumed_indicies :] + assert len(indices) == self.num_samples - self.consumed_indicies + + return iter(indices) + + def __len__(self) -> int: + return self.num_samples - self.consumed_indicies + + def set_epoch(self, epoch: int, consumed_samples=0) -> None: + r""" + Set the epoch for this sampler. + + When :attr:`shuffle=True`, this ensures all replicas + use a different random ordering for each epoch. Otherwise, the next iteration of this + sampler will yield the same ordering. + + Args: + epoch (int): Epoch number. 
+ """ + self.epoch = epoch + self.consumed_indicies = consumed_samples // self.num_replicas diff --git a/PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/openrlhf/utils/distributed_util.py b/PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/openrlhf/utils/distributed_util.py new file mode 100644 index 0000000000..9b91044cda --- /dev/null +++ b/PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/openrlhf/utils/distributed_util.py @@ -0,0 +1,72 @@ +from datetime import timedelta +from typing import Any, Optional, Union + +import torch +import torch.distributed +from torch.distributed.distributed_c10d import ( + Backend, + PrefixStore, + Store, + _new_process_group_helper, + _world, + default_pg_timeout, + rendezvous, +) + + +# Copy from pytorch to allow creating multiple main groups. +# https://github.com/pytorch/pytorch/blob/main/torch/distributed/distributed_c10d.py +def init_process_group( + backend: Union[str, Backend] = None, + init_method: Optional[str] = None, + timeout: Optional[timedelta] = None, + world_size: int = -1, + rank: int = -1, + store: Optional[Store] = None, + group_name: str = None, + pg_options: Optional[Any] = None, +): + assert (store is None) or (init_method is None), "Cannot specify both init_method and store." + + if store is not None: + assert world_size > 0, "world_size must be positive if using store" + assert rank >= 0, "rank must be non-negative if using store" + elif init_method is None: + init_method = "env://" + + if backend: + backend = Backend(backend) + else: + backend = Backend("undefined") + + if timeout is None: + timeout = default_pg_timeout + + # backward compatible API + if store is None: + rendezvous_iterator = rendezvous(init_method, rank, world_size, timeout=timeout) + store, rank, world_size = next(rendezvous_iterator) + store.set_timeout(timeout) + + # Use a PrefixStore to avoid accidental overrides of keys used by + # different systems (e.g. RPC) in case the store is multi-tenant. 
+ store = PrefixStore(group_name, store) + + # NOTE: The pg_options parameter was renamed into backend_options in PyTorch 2.6.0 + # https://github.com/pytorch/pytorch/commit/a0c7029a75628cd5fa8df83c0de0ea98ee7fd844 + # We need to determine the appropriate parameter name based on PyTorch version + pg_options_param_name = "backend_options" if str(torch.__version__) >= "2.6" else "pg_options" + pg, _ = _new_process_group_helper( + world_size, + rank, + [], + backend, + store, + group_name=group_name, + **{pg_options_param_name: pg_options}, + timeout=timeout, + ) + + _world.pg_group_ranks[pg] = {i: i for i in range(world_size)} + + return pg diff --git a/PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/openrlhf/utils/logging_utils.py b/PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/openrlhf/utils/logging_utils.py new file mode 100644 index 0000000000..eb39f39a23 --- /dev/null +++ b/PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/openrlhf/utils/logging_utils.py @@ -0,0 +1,56 @@ +# Adapted from +# https://github.com/skypilot-org/skypilot/blob/86dc0f6283a335e4aa37b3c10716f90999f48ab6/sky/sky_logging.py +"""Logging configuration for vLLM.""" +import logging +import sys + +_FORMAT = "%(levelname)s %(asctime)s %(filename)s:%(lineno)d] %(message)s" +_DATE_FORMAT = "%m-%d %H:%M:%S" + + +class NewLineFormatter(logging.Formatter): + """Adds logging prefix to newlines to align multi-line messages.""" + + def __init__(self, fmt, datefmt=None): + logging.Formatter.__init__(self, fmt, datefmt) + + def format(self, record): + msg = logging.Formatter.format(self, record) + if record.message != "": + parts = msg.split(record.message) + msg = msg.replace("\n", "\r\n" + parts[0]) + return msg + + +_root_logger = logging.getLogger("openrlhf") +_default_handler = None + + +def _setup_logger(): + _root_logger.setLevel(logging.DEBUG) + global _default_handler + if _default_handler is None: + _default_handler = logging.StreamHandler(sys.stdout) + _default_handler.flush = sys.stdout.flush # type: ignore + _default_handler.setLevel(logging.INFO) + _root_logger.addHandler(_default_handler) + fmt = NewLineFormatter(_FORMAT, datefmt=_DATE_FORMAT) + _default_handler.setFormatter(fmt) + # Setting this will avoid the message + # being propagated to the parent logger. + _root_logger.propagate = False + + +# The logger is initialized when the module is imported. +# This is thread-safe as the module is only imported once, +# guaranteed by the Python GIL. 
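# Editorial aside (not part of the patch), on the version check in init_process_group above:
# comparing version strings lexicographically ("2.10" < "2.6") can misfire once minor
# versions reach two digits. One more robust alternative, sketched here assuming the
# `packaging` package already listed in requirements.txt:
import torch
from packaging import version

pg_options_param_name = (
    "backend_options"
    if version.parse(torch.__version__).release >= (2, 6)
    else "pg_options"
)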
+_setup_logger() + + +def init_logger(name: str): + # Use the same settings as above for root logger + logger = logging.getLogger(name) + logger.setLevel(logging.DEBUG) + logger.addHandler(_default_handler) + logger.propagate = False + return logger diff --git a/PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/openrlhf/utils/processor.py b/PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/openrlhf/utils/processor.py new file mode 100644 index 0000000000..96036f4dee --- /dev/null +++ b/PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/openrlhf/utils/processor.py @@ -0,0 +1,103 @@ +import torch +from tqdm import tqdm + + +def reward_normalization(objs): + rewards = [float(obj["reward"]) for obj in objs] + rewards = torch.tensor(rewards, dtype=torch.float64) + rewards = (rewards - rewards.mean()) / rewards.std() + for i, obj in enumerate(objs): + obj["reward"] = rewards[i].item() + + +# Conditional SFT +# See https://arxiv.org/abs/2308.12050 +DEFAULT_REWARD_PROMPT = "{input} : {reward} " + + +def conditional_sft_processor(args, objs): + if "reward_template" not in args or args.reward_template is None: + reward_template = DEFAULT_REWARD_PROMPT + else: + reward_template = args.reward_template + assert "{input}" in reward_template + assert "{reward}" in reward_template + + if args.normalize_reward: + reward_normalization(objs) + + for obj in tqdm(objs, desc="Conditional SFT process..."): + input = obj["input"] + reward = "{:.2f}".format(float(obj["reward"])) + input = reward_template.replace("{reward}", reward).replace("{input}", input) + obj["input"] = input + + return objs + + +# Rejection Sampling +# See https://arxiv.org/abs/2307.09288 +def rejection_sampling_processor(args, objs): + out = {} + for obj in tqdm(objs, desc="Rejection Sampling process...."): + input = obj["input"] + output = obj["output"] + reward = float(obj["reward"]) + + if input not in out: + out[input] = {"output": output, "reward": reward} + elif reward > out[input]["reward"]: + out[input]["reward"] = reward + out[input]["output"] = output + + return [{"input": k, "output": v["output"], "reward": v["reward"]} for k, v in out.items()] + + +# Iterative DPO +# See https://github.com/RLHFlow/Online-RLHF/blob/main/run_loop.sh +def iterative_dpo_processor(args, objs): + out = {} + for obj in tqdm(objs, desc="Iterative DPO process...."): + input = obj["input"] + output = obj["output"] + reward = float(obj["reward"]) + + if input not in out: + out[input] = { + "output": output, + "chosen": output, + "chosen_reward": reward, + "rejected": output, + "rejected_reward": reward, + } + elif reward > out[input]["chosen_reward"]: + out[input]["chosen_reward"] = reward + out[input]["chosen"] = output + elif reward < out[input]["rejected_reward"]: + out[input]["rejected_reward"] = reward + out[input]["rejected"] = output + + return [ + { + "prompt": k, + "chosen": v["chosen"], + "chosen_reward": v["chosen_reward"], + "rejected": v["rejected"], + "rejected_reward": v["rejected_reward"], + } + for k, v in out.items() + ] + + +PROCESSORS = { + "rs": rejection_sampling_processor, + "csft": conditional_sft_processor, + "iter_dpo": iterative_dpo_processor, +} + + +def get_processor(name): + if name in PROCESSORS: + return PROCESSORS[name] + else: + raise ValueError(f"Processor {name} does not exist.") diff --git a/PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/openrlhf/utils/remote_rm_utils.py b/PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/openrlhf/utils/remote_rm_utils.py new file mode 100644 index 0000000000..aef2941c46 --- /dev/null +++ 
b/PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/openrlhf/utils/remote_rm_utils.py @@ -0,0 +1,52 @@ +import time +import ray +import requests +import torch + +from openrlhf.utils.logging_utils import init_logger + +logger = init_logger(__name__) + + +def request_api_wrapper(url, data, score_key="rewards", try_max_times=5): + """Synchronous request API wrapper""" + headers = { + "Content-Type": "application/json", + } + for _ in range(try_max_times): + try: + response = requests.post(url=url, json=data, headers=headers, timeout=180) + response.raise_for_status() # Raise an HTTPError for bad responses + response = response.json() + assert score_key in response, f"{score_key} not in {response}" + return response.get(score_key) + except requests.RequestException as e: + logger.info(f"Request error, please check: {e}") + except Exception as e: + logger.info(f"Unexpected error, please check: {e}") + time.sleep(1) + + raise Exception(f"Request error for {try_max_times} times, returning None. Please check the API server.") + + +def remote_rm_fn(api_url, queries, score_key="rewards"): + """remote reward model API + api_url: RM API, We assume that the API supports two modes: merging query + response and not merging + queries: query+response with the template + design is made optional. + score_key: RM score key + """ + scores = request_api_wrapper(api_url, {"query": queries}, score_key) + return torch.tensor(scores) + + +@ray.remote +def remote_rm_fn_ray(api_url, queries, score_key="rewards"): + return remote_rm_fn(api_url, queries, score_key) + + +if __name__ == "__main__": + # test utils + url = "http:xxx/get_rm_score" + score = remote_rm_fn(url, ["example query"], ["example response"]) + print(score) diff --git a/PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/openrlhf/utils/utils.py b/PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/openrlhf/utils/utils.py new file mode 100644 index 0000000000..a69b13ece3 --- /dev/null +++ b/PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/openrlhf/utils/utils.py @@ -0,0 +1,125 @@ +import os + +from datasets import interleave_datasets, load_dataset, load_from_disk +from transformers import AutoTokenizer + + +def get_tokenizer(pretrain, model, padding_side="left", strategy=None, use_fast=True): + tokenizer = AutoTokenizer.from_pretrained(pretrain, trust_remote_code=True, use_fast=use_fast) + tokenizer.padding_side = padding_side + # NOTE: When enable vLLM, do not resize_token_embeddings, or the vocab size will mismatch with vLLM. 
+ # https://github.com/facebookresearch/llama-recipes/pull/196 + if tokenizer.pad_token is None: + tokenizer.pad_token = tokenizer.eos_token + tokenizer.pad_token_id = tokenizer.eos_token_id + model.config.pad_token_id = tokenizer.pad_token_id + + return tokenizer + + +def get_strategy(args): + from openrlhf.utils.deepspeed import DeepspeedStrategy + + strategy = DeepspeedStrategy( + seed=getattr(args, "seed", 42), + max_norm=getattr(args, "max_norm", 1.0), + micro_train_batch_size=getattr(args, "micro_train_batch_size", 1), + train_batch_size=getattr(args, "train_batch_size", 128), + zero_stage=args.zero_stage, + bf16=getattr(args, "bf16", True), + args=args, + ) + return strategy + + +def blending_datasets( + datasets, + probabilities, + strategy=None, + seed=42, + max_count=5000000, + return_eval=True, + stopping_strategy="first_exhausted", + train_split="train", + eval_split="test", +): + datasets = datasets.split(",") + probabilities = list(map(float, probabilities.split(","))) + assert len(probabilities) == len(datasets) + + train_data_list = [] + eval_data_list = [] + for i, dataset in enumerate(datasets): + dataset = dataset.strip() + strategy.print(f"dataset: {dataset}") + + data_dir = dataset.split("@")[1].strip() if "@" in dataset else None + dataset = dataset.split("@")[0].strip() + dataset_basename = os.path.basename(dataset) + + ext = os.path.splitext(dataset)[-1] + # local python script + if ext == ".py" or ( + os.path.isdir(dataset) and os.path.exists(os.path.join(dataset, f"{dataset_basename}.py")) + ): + data = load_dataset(dataset, trust_remote_code=True) + strategy.print(f"loaded {dataset} with python script") + # local text file + elif ext in [".json", ".jsonl", ".csv"]: + ext = ext.lower().strip(".") + if ext == "jsonl": + ext = "json" + data = load_dataset(ext, data_files=dataset) + strategy.print(f"loaded {dataset} with data_files={dataset}") + # local dataset saved with `datasets.Dataset.save_to_disk` + elif os.path.isdir(dataset): + data = load_from_disk(dataset) + strategy.print(f"loaded {dataset} from disk") + # remote/local folder or common file + else: + data = load_dataset(dataset, data_dir=data_dir) + strategy.print(f"loaded {dataset} from files") + + if train_split and train_split in data: + train_data = data[train_split].select(range(min(max_count, len(data[train_split])))) + else: + train_data = data.select(range(min(max_count, len(data)))) + train_data_list.append(train_data) + + if return_eval: + if eval_split and eval_split in data: + eval_data = data[eval_split].select(range(min(max_count, len(data[eval_split])))) + # train will contains eval? 
TODO + else: + eval_data = train_data.select(range(min(max_count, int(len(train_data) * 0.03)))) + eval_data_list.append(eval_data) + + # merge datasets + if strategy.is_rank_0(): + print(train_data_list) + + train_dataset = interleave_datasets( + train_data_list, + probabilities=probabilities, + seed=seed, + stopping_strategy=stopping_strategy, + ) + if return_eval: + eval_dataset = interleave_datasets( + eval_data_list, + probabilities=probabilities, + seed=seed, + stopping_strategy=stopping_strategy, + ) + return train_dataset, eval_dataset + else: + return train_dataset + + +def convert_token_to_id(token, tokenizer): + if isinstance(token, str): + token = tokenizer.encode(token, add_special_tokens=False) + assert len(token) == 1 + return token[0] + else: + raise ValueError("token should be int or str") diff --git a/PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/pyproject.toml b/PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/pyproject.toml new file mode 100644 index 0000000000..3c2b7d3baf --- /dev/null +++ b/PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/pyproject.toml @@ -0,0 +1,56 @@ +[build-system] +requires = [ + "packaging", + "setuptools >= 49.4.0", + "wheel", +] +build-backend = "setuptools.build_meta" + +[tool.isort] +profile = "black" # black-compatible +line_length = 119 # should match black parameters +ignore_whitespace = true # ignore whitespace for compatibility with the initial style +py_version = 310 # python 3.10 as a target version +sections = ["FUTURE", "STDLIB", "THIRDPARTY", "FIRSTPARTY", "LOCALFOLDER"] +default_section = "THIRDPARTY" +extend_skip = ["setup.py", "docs/source/conf.py"] + + +[tool.black] +line_length = 119 + +[tool.ruff] +line-length = 119 + +[tool.pytest.ini_options] +# durations=0 will display all tests execution time, sorted in ascending order starting from from the slowest one. +# -vv will also display tests with durration = 0.00s +addopts = "--verbose --pyargs --durations=0 --strict-markers" # always add these arguments to pytest +testpaths = ["./tests"] # must be an explicit path to avoid importing another "tests" module +# directories to ignore when discovering tests +norecursedirs = [ + "external", + "examples", + "docs", + "scripts", + "tools", + "tutorials", + "*.egg", + ".*", + "_darcs", + "build", + "CVS", + "dist", + "venv", + "{arch}", +] +# markers to select tests, use `pytest --markers` to see all available markers, `pytest -m ""` to select tests +markers = [ + "unit: marks unit test, i.e. 
diff --git a/PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/pyproject.toml b/PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/pyproject.toml
new file mode 100644
index 0000000000..3c2b7d3baf
--- /dev/null
+++ b/PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/pyproject.toml
@@ -0,0 +1,56 @@
+[build-system]
+requires = [
+    "packaging",
+    "setuptools >= 49.4.0",
+    "wheel",
+]
+build-backend = "setuptools.build_meta"
+
+[tool.isort]
+profile = "black"  # black-compatible
+line_length = 119  # should match black parameters
+ignore_whitespace = true  # ignore whitespace for compatibility with the initial style
+py_version = 310  # python 3.10 as a target version
+sections = ["FUTURE", "STDLIB", "THIRDPARTY", "FIRSTPARTY", "LOCALFOLDER"]
+default_section = "THIRDPARTY"
+extend_skip = ["setup.py", "docs/source/conf.py"]
+
+
+[tool.black]
+line_length = 119
+
+[tool.ruff]
+line-length = 119
+
+[tool.pytest.ini_options]
+# durations=0 displays every test's execution time, sorted in ascending order starting from the slowest one.
+# -vv will also display tests with duration = 0.00s
+addopts = "--verbose --pyargs --durations=0 --strict-markers"  # always add these arguments to pytest
+testpaths = ["./tests"]  # must be an explicit path to avoid importing another "tests" module
+# directories to ignore when discovering tests
+norecursedirs = [
+    "external",
+    "examples",
+    "docs",
+    "scripts",
+    "tools",
+    "tutorials",
+    "*.egg",
+    ".*",
+    "_darcs",
+    "build",
+    "CVS",
+    "dist",
+    "venv",
+    "{arch}",
+]
+# markers to select tests, use `pytest --markers` to see all available markers, `pytest -m "<marker>"` to select tests
+markers = [
+    "unit: marks unit tests, i.e. testing a single, well isolated functionality (deselect with '-m \"not unit\"')",
+    "integration: marks tests checking the elements when integrated into subsystems (deselect with '-m \"not integration\"')",
+    "system: marks tests working at the highest integration level (deselect with '-m \"not system\"')",
+    "acceptance: marks tests checking whether the developed product/model passes the user-defined acceptance criteria (deselect with '-m \"not acceptance\"')",
+    "docs: marks tests related to documentation (deselect with '-m \"not docs\"')",
+    "skipduringci: marks tests that are skipped during CI because they are covered by Jenkins jobs, but that should still be run to test user setups",
+    "pleasefixme: marks tests that are broken and need fixing",
+]
diff --git a/PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/requirements.txt b/PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/requirements.txt
new file mode 100644
index 0000000000..3cdcd77765
--- /dev/null
+++ b/PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/requirements.txt
@@ -0,0 +1,21 @@
+accelerate
+bitsandbytes
+datasets
+deepspeed==0.15.0
+einops
+flash-attn==2.7.0.post2
+isort
+jsonlines
+loralib
+optimum
+packaging
+peft
+ray[default]==2.12.0
+tensorboard
+torch
+torchmetrics
+tqdm
+transformers==4.46.3
+transformers_stream_generator
+wandb
+wheel
diff --git a/PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/setup.py b/PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/setup.py
new file mode 100644
index 0000000000..be536a7366
--- /dev/null
+++ b/PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/setup.py
@@ -0,0 +1,89 @@
+import os
+import sys
+import platform
+
+from datetime import datetime
+from setuptools import find_packages, setup
+from wheel.bdist_wheel import bdist_wheel as _bdist_wheel
+
+_build_mode = os.getenv("OPENRLHF_BUILD_MODE", "")
+
+
+def _is_nightly():
+    return _build_mode.lower() == "nightly"
+
+
+def _fetch_requirements(path):
+    with open(path, "r") as fd:
+        return [r.strip() for r in fd.readlines()]
+
+
+def _fetch_readme():
+    with open("README.md", encoding="utf-8") as f:
+        return f.read()
+
+
+def _fetch_version():
+    with open("version.txt", "r") as f:
+        version = f.read().strip()
+
+    if _is_nightly():
+        now = datetime.now()
+        date_str = now.strftime("%Y%m%d")
+        version += f".dev{date_str}"
+
+    return version
+
+
+def _fetch_package_name():
+    return "openrlhf-nightly" if _is_nightly() else "openrlhf"
+
+
+# Custom wheel class to modify the wheel name
+class bdist_wheel(_bdist_wheel):
+    def finalize_options(self):
+        _bdist_wheel.finalize_options(self)
+        self.root_is_pure = False
+
+    def get_tag(self):
+        python_version = f"cp{sys.version_info.major}{sys.version_info.minor}"
+        abi_tag = f"{python_version}"
+
+        if platform.system() == "Linux":
+            platform_tag = "manylinux1_x86_64"
+        else:
+            platform_tag = platform.system().lower()
+
+        return python_version, abi_tag, platform_tag
+
+
+# Setup configuration
+setup(
+    author="OpenRLHF Team",
+    name=_fetch_package_name(),
+    version=_fetch_version(),
+    packages=find_packages(
+        exclude=(
+            "data",
+            "docs",
+            "examples",
+        )
+    ),
+    description="A Ray-based High-performance RLHF framework.",
+    long_description=_fetch_readme(),
+    long_description_content_type="text/markdown",
+    install_requires=_fetch_requirements("requirements.txt"),
+    extras_require={
+        "vllm": ["vllm==0.6.5"],
+        "vllm_latest": ["vllm>0.6.5"],
+    },
+    python_requires=">=3.10",
+    classifiers=[
+        "Programming Language :: Python :: 3.10",
+        "Programming Language :: Python :: 3.11",
+        "Environment :: GPU :: NVIDIA CUDA",
+        "Topic :: Scientific/Engineering :: Artificial Intelligence",
+        "Topic :: System :: Distributed Computing",
+    ],
+    cmdclass={"bdist_wheel": bdist_wheel},
+)
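For orientation, a standalone re-derivation of what the custom `bdist_wheel` tag and the nightly `.dev` suffix in `setup.py` above combine into; the values below are assumptions (Python 3.10 on Linux with `OPENRLHF_BUILD_MODE=nightly`) and nothing is imported from the package:

```python
# Re-derives the wheel filename produced by setup.py's nightly build path
# (illustrative only; base_version mirrors version.txt, the rest is computed here).
import platform
import sys
from datetime import datetime

base_version = "0.5.7"  # contents of version.txt
version = f"{base_version}.dev{datetime.now():%Y%m%d}"  # _fetch_version() with OPENRLHF_BUILD_MODE=nightly

py_tag = f"cp{sys.version_info.major}{sys.version_info.minor}"  # e.g. cp310
abi_tag = py_tag                                                # setup.py reuses the python tag as the ABI tag
plat_tag = "manylinux1_x86_64" if platform.system() == "Linux" else platform.system().lower()

# The package name "openrlhf-nightly" is normalised to "openrlhf_nightly" in the wheel filename.
print(f"openrlhf_nightly-{version}-{py_tag}-{abi_tag}-{plat_tag}.whl")
```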
diff --git a/PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/version.txt b/PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/version.txt
new file mode 100644
index 0000000000..dc2b74e60a
--- /dev/null
+++ b/PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/version.txt
@@ -0,0 +1 @@
+0.5.7
\ No newline at end of file
--
Gitee