diff --git a/PyTorch/built-in/rl/OpenRLHF_v0.6.2_for_PyTorch/openrlhf/__init__.py b/PyTorch/built-in/rl/OpenRLHF_v0.6.2_for_PyTorch/openrlhf/__init__.py
index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..9bad7ff74e19de39ce183f7197ae102fcc6fca3e 100644
--- a/PyTorch/built-in/rl/OpenRLHF_v0.6.2_for_PyTorch/openrlhf/__init__.py
+++ b/PyTorch/built-in/rl/OpenRLHF_v0.6.2_for_PyTorch/openrlhf/__init__.py
@@ -0,0 +1,10 @@
+import importlib
+
+
+ACCELERATOR_TYPE = "GPU"
+
+if importlib.util.find_spec("torch_npu"):
+    ACCELERATOR_TYPE = "NPU"
+
+    import torch_npu
+    from torch_npu.contrib import transfer_to_npu
diff --git a/PyTorch/built-in/rl/OpenRLHF_v0.6.2_for_PyTorch/openrlhf/cli/train_ppo_ray.py b/PyTorch/built-in/rl/OpenRLHF_v0.6.2_for_PyTorch/openrlhf/cli/train_ppo_ray.py
index e0c1756fa192f36200ae339788843d8c4b2bd922..0052989d63466d0e984b0dbb505f4b042c5c0727 100644
--- a/PyTorch/built-in/rl/OpenRLHF_v0.6.2_for_PyTorch/openrlhf/cli/train_ppo_ray.py
+++ b/PyTorch/built-in/rl/OpenRLHF_v0.6.2_for_PyTorch/openrlhf/cli/train_ppo_ray.py
@@ -15,6 +15,7 @@ from openrlhf.trainer.ray import (
     create_vllm_engines,
 )
 from openrlhf.utils import get_strategy
+from openrlhf import ACCELERATOR_TYPE
 
 
 # NOTE: reward function for multiple reward models, replace this with your own function!
@@ -65,7 +66,7 @@ def train(args):
             and args.actor_num_gpus_per_node == args.ref_num_gpus_per_node
         ), f"num_nodes and num_gpus_per_node must be the same when colocate actor and ref model."
 
-        bundles = [{"GPU": 1, "CPU": 1} for _ in range(args.actor_num_nodes * args.actor_num_gpus_per_node)]
+        bundles = [{ACCELERATOR_TYPE: 1, "CPU": 1} for _ in range(args.actor_num_nodes * args.actor_num_gpus_per_node)]
         pg = placement_group(bundles, strategy="PACK")
         ray.get(pg.ready())
 
@@ -126,7 +127,7 @@ def train(args):
             and args.critic_num_gpus_per_node == args.reward_num_gpus_per_node
         ), f"num_nodes and num_gpus_per_node must be the same when colocate critic and reward model."
 
-        bundles = [{"GPU": 1, "CPU": 1} for _ in range(args.critic_num_nodes * args.critic_num_gpus_per_node)]
+        bundles = [{ACCELERATOR_TYPE: 1, "CPU": 1} for _ in range(args.critic_num_nodes * args.critic_num_gpus_per_node)]
         pg = placement_group(bundles, strategy="PACK")
         ray.get(pg.ready())
 
@@ -188,6 +189,8 @@ def train(args):
     if args.critic_pretrain and args.save_value_network:
         ray.get(critic_model.async_save_model())
 
+    # temp solution: Avoid main process disappeared error on Ascend NPU
+    ray.shutdown()
 
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
diff --git a/PyTorch/built-in/rl/OpenRLHF_v0.6.2_for_PyTorch/openrlhf/models/actor.py b/PyTorch/built-in/rl/OpenRLHF_v0.6.2_for_PyTorch/openrlhf/models/actor.py
index 8027665d5e24eca11143ad4eb5d06cab59672564..9b6e444f70f334f2a25f6f843f0e3e3e0b4d115a 100644
--- a/PyTorch/built-in/rl/OpenRLHF_v0.6.2_for_PyTorch/openrlhf/models/actor.py
+++ b/PyTorch/built-in/rl/OpenRLHF_v0.6.2_for_PyTorch/openrlhf/models/actor.py
@@ -3,16 +3,19 @@ from typing import Optional, Tuple, Union
 import torch
 import torch.distributed as dist
 import torch.nn as nn
-from flash_attn.utils.distributed import all_gather
 from peft import LoraConfig, TaskType, get_peft_model
 from peft.tuners.lora import LoraLayer
 from torch.nn import functional as F
 from transformers import AutoModelForCausalLM, BitsAndBytesConfig
 from transformers.integrations.deepspeed import HfDeepSpeedConfig
+from transformers.utils import is_flash_attn_2_available
 
 from .ring_attn_utils import convert_ring_attn_params
 from .utils import log_probs_from_logits, reset_position_ids
 
+if is_flash_attn_2_available():
+    from flash_attn.utils.distributed import all_gather
+
 
 class Actor(nn.Module):
     """
diff --git a/PyTorch/built-in/rl/OpenRLHF_v0.6.2_for_PyTorch/openrlhf/models/model.py b/PyTorch/built-in/rl/OpenRLHF_v0.6.2_for_PyTorch/openrlhf/models/model.py
index 49d8af1050137e3f8afe4e841a9cb1dafb3c887a..e48f35cb8f9ae8854c2484bb38a016a51b941b60 100644
--- a/PyTorch/built-in/rl/OpenRLHF_v0.6.2_for_PyTorch/openrlhf/models/model.py
+++ b/PyTorch/built-in/rl/OpenRLHF_v0.6.2_for_PyTorch/openrlhf/models/model.py
@@ -3,17 +3,20 @@ from typing import Optional, Union
 import deepspeed
 import torch
 import torch.nn as nn
-from flash_attn.utils.distributed import all_gather
 from peft import LoraConfig, get_peft_model
 from peft.tuners.lora import LoraLayer
 from transformers import AutoConfig, AutoModel, BitsAndBytesConfig
 from transformers.integrations.deepspeed import HfDeepSpeedConfig
+from transformers.utils import is_flash_attn_2_available
 
 from openrlhf.utils.logging_utils import init_logger
 
 from .ring_attn_utils import convert_ring_attn_params
 from .utils import reset_position_ids
 
+if is_flash_attn_2_available():
+    from flash_attn.utils.distributed import all_gather
+
 logger = init_logger(__name__)
 
 
diff --git a/PyTorch/built-in/rl/OpenRLHF_v0.6.2_for_PyTorch/openrlhf/trainer/dpo_trainer.py b/PyTorch/built-in/rl/OpenRLHF_v0.6.2_for_PyTorch/openrlhf/trainer/dpo_trainer.py
index 904698fe0edd15dd8502351d778d06984e32ce1f..e861b80383c6ebbb7b67d0f827f0793629eb8439 100644
--- a/PyTorch/built-in/rl/OpenRLHF_v0.6.2_for_PyTorch/openrlhf/trainer/dpo_trainer.py
+++ b/PyTorch/built-in/rl/OpenRLHF_v0.6.2_for_PyTorch/openrlhf/trainer/dpo_trainer.py
@@ -2,7 +2,6 @@ import os
 from abc import ABC
 
 import torch
-from flash_attn.utils.distributed import all_gather
 from torch.nn import functional as F
 from torch.optim import Optimizer
 from tqdm import tqdm
@@ -10,6 +9,10 @@ from tqdm import tqdm
 from openrlhf.models import DPOLoss
 from openrlhf.models.utils import log_probs_from_logits
 from openrlhf.utils.distributed_sampler import DistributedSampler
+from transformers.utils import is_flash_attn_2_available
+
+if is_flash_attn_2_available():
+    from flash_attn.utils.distributed import all_gather
 
 
 class DPOTrainer(ABC):
diff --git a/PyTorch/built-in/rl/OpenRLHF_v0.6.2_for_PyTorch/openrlhf/trainer/ray/launcher.py b/PyTorch/built-in/rl/OpenRLHF_v0.6.2_for_PyTorch/openrlhf/trainer/ray/launcher.py
index b5512821a571f53f95219b60883f23460c8dd1b3..4d18d21f1c1c23e2425acb8b2e5adff7e9957f8b 100644
--- a/PyTorch/built-in/rl/OpenRLHF_v0.6.2_for_PyTorch/openrlhf/trainer/ray/launcher.py
+++ b/PyTorch/built-in/rl/OpenRLHF_v0.6.2_for_PyTorch/openrlhf/trainer/ray/launcher.py
@@ -11,6 +11,7 @@ from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy
 from openrlhf.models import Actor, get_llm_for_sequence_regression
 from openrlhf.trainer.ray.utils import ray_noset_visible_devices
 from openrlhf.utils.deepspeed import DeepspeedStrategy
+from openrlhf import ACCELERATOR_TYPE
 
 
 class DistributedTorchRayActor:
@@ -32,7 +33,8 @@ class DistributedTorchRayActor:
         # environment variable for each actor, unless
         # RAY_EXPERIMENTAL_NOSET_*_VISIBLE_DEVICES is set, so
         # set local rank to 0 when the flag is not applicable.
-        os.environ["LOCAL_RANK"] = str(ray.get_gpu_ids()[0]) if ray_noset_visible_devices() else "0"
+        os.environ["LOCAL_RANK"] = str(ray.get_runtime_context().get_accelerator_ids()[ACCELERATOR_TYPE][0]) \
+            if ray_noset_visible_devices() else "0"
 
     @staticmethod
     def _get_current_node_ip():
@@ -60,7 +62,7 @@ class BasePPORole(DistributedTorchRayActor):
         raise NotImplementedError()
 
 
-@ray.remote(num_gpus=1)
+@ray.remote
 class ReferenceModelRayActor(BasePPORole):
     def init_model_from_pretrained(self, strategy: DeepspeedStrategy, pretrain):
         self._setup_distributed(strategy)
@@ -109,7 +111,7 @@ class ReferenceModelRayActor(BasePPORole):
         torch.cuda.empty_cache()
 
 
-@ray.remote(num_gpus=1)
+@ray.remote
 class RewardModelRayActor(BasePPORole):
     def init_model_from_pretrained(self, strategy: DeepspeedStrategy, pretrain):
         self._setup_distributed(strategy)
@@ -186,18 +188,22 @@ class PPORayActorGroup:
         self._num_gpus_per_node = num_gpus_per_node
         self.ray_actor_type = ray_actor_type
 
-        # custom resources, see https://docs.ray.io/en/latest/ray-core/scheduling/resources.html
-        self._resources = resources
-        self._num_resources_per_node = num_resources_per_node
-
-        self._initiate_actors(pg, num_gpus_per_actor)
+        if ACCELERATOR_TYPE == "GPU":
+            # custom resources, see https://docs.ray.io/en/latest/ray-core/scheduling/resources.html
+            self._resources = resources
+            self._num_resources_per_node = num_resources_per_node
+            self._initiate_actors(pg, num_gpus_per_actor)
+        elif ACCELERATOR_TYPE == "NPU":
+            self._resources = {ACCELERATOR_TYPE: num_gpus_per_actor}
+            self._num_resources_per_node = num_gpus_per_actor
+            self._initiate_actors(pg, 0)
 
     def _initiate_actors(self, pg, num_gpus_per_actor):
         world_size = self._num_nodes * self._num_gpus_per_node
 
         # Use placement group to lock resources for models of same type
         if self._num_gpus_per_node > 1 and pg is None:
-            bundles = [{"GPU": 1, "CPU": 1} for _ in range(self._num_nodes * self._num_gpus_per_node)]
+            bundles = [{ACCELERATOR_TYPE: 1, "CPU": 1} for _ in range(self._num_nodes * self._num_gpus_per_node)]
             if self._resources:
                 resources_name = list(self._resources.keys())[0]
                 for i in range(len(bundles)):
diff --git a/PyTorch/built-in/rl/OpenRLHF_v0.6.2_for_PyTorch/openrlhf/trainer/ray/ppo_actor.py b/PyTorch/built-in/rl/OpenRLHF_v0.6.2_for_PyTorch/openrlhf/trainer/ray/ppo_actor.py
index 665507e76ff91162a802106d4e5c52af934b4edb..9bcba084777e869759ba1d45b1eaee3e2c7ced6e 100644
--- a/PyTorch/built-in/rl/OpenRLHF_v0.6.2_for_PyTorch/openrlhf/trainer/ray/ppo_actor.py
+++ b/PyTorch/built-in/rl/OpenRLHF_v0.6.2_for_PyTorch/openrlhf/trainer/ray/ppo_actor.py
@@ -19,6 +19,7 @@ from openrlhf.utils import blending_datasets, get_tokenizer
 from openrlhf.utils.deepspeed import DeepspeedStrategy
 from openrlhf.utils.deepspeed.deepspeed_utils import offload_deepspeed_states, reload_deepspeed_states
 from openrlhf.utils.distributed_util import init_process_group
+from openrlhf import ACCELERATOR_TYPE
 
 from .launcher import BasePPORole
 from .utils import get_physical_gpu_id
@@ -59,7 +60,11 @@ class ActorPPOTrainer(PPOTrainer):
             packing_samples=self.strategy.args.packing_samples,
         )
 
-        backend = getattr(self.strategy.args, "vllm_sync_backend", "nccl")
+        if ACCELERATOR_TYPE == "GPU":
+            backend = getattr(self.strategy.args, "vllm_sync_backend", "nccl")
+        elif ACCELERATOR_TYPE == "NPU":
+            backend = "hccl"
+
         self.use_cuda_ipc = False
         if backend == "nccl" and self.strategy.args.colocate_all_models:
             self.use_cuda_ipc = True
@@ -287,7 +292,7 @@ class ActorPPOTrainer(PPOTrainer):
         offload_deepspeed_states(self.actor.model)
 
 
-@ray.remote(num_gpus=1)
+@ray.remote
 class ActorModelRayActor(BasePPORole):
     def init_model_from_pretrained(self, strategy: DeepspeedStrategy, pretrain):
         args = strategy.args
diff --git a/PyTorch/built-in/rl/OpenRLHF_v0.6.2_for_PyTorch/openrlhf/trainer/ray/ppo_critic.py b/PyTorch/built-in/rl/OpenRLHF_v0.6.2_for_PyTorch/openrlhf/trainer/ray/ppo_critic.py
index 443a36838ff0f9c3e73718de39683aaff4098eba..79c3c7f3533eab83af82029b7f88b51b6cd0491a 100644
--- a/PyTorch/built-in/rl/OpenRLHF_v0.6.2_for_PyTorch/openrlhf/trainer/ray/ppo_critic.py
+++ b/PyTorch/built-in/rl/OpenRLHF_v0.6.2_for_PyTorch/openrlhf/trainer/ray/ppo_critic.py
@@ -62,7 +62,7 @@ class CriticPPOTrainer(PPOTrainer):
         return self.training_step_critic(experience)
 
 
-@ray.remote(num_gpus=1)
+@ray.remote
 class CriticModelRayActor(BasePPORole):
     def init_model_from_pretrained(self, strategy: DeepspeedStrategy, pretrain, max_steps):
         args = strategy.args
diff --git a/PyTorch/built-in/rl/OpenRLHF_v0.6.2_for_PyTorch/openrlhf/trainer/ray/vllm_engine.py b/PyTorch/built-in/rl/OpenRLHF_v0.6.2_for_PyTorch/openrlhf/trainer/ray/vllm_engine.py
index 7fb0eadfc485d000fbd93b6215092330f54079de..3f77da3d3df607924f536e3a08c4b982086242fe 100644
--- a/PyTorch/built-in/rl/OpenRLHF_v0.6.2_for_PyTorch/openrlhf/trainer/ray/vllm_engine.py
+++ b/PyTorch/built-in/rl/OpenRLHF_v0.6.2_for_PyTorch/openrlhf/trainer/ray/vllm_engine.py
@@ -9,6 +9,7 @@ from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy
 from vllm import LLM
 
 from openrlhf.utils.logging_utils import init_logger
+from openrlhf import ACCELERATOR_TYPE
 
 from .utils import ray_noset_visible_devices
 
@@ -36,7 +37,8 @@ class LLMRayActor:
             # We need to set CUDA_VISIBLE_DEVICES to the ray assigned GPU
             # when the distributed_executor_backend is not ray and
             # RAY_EXPERIMENTAL_NOSET_*_VISIBLE_DEVICES is set.
-            os.environ["CUDA_VISIBLE_DEVICES"] = str(ray.get_gpu_ids()[0])
+            if ACCELERATOR_TYPE == "GPU":
+                os.environ["CUDA_VISIBLE_DEVICES"] = str(ray.get_gpu_ids()[0])
 
         num_gpus = kwargs.pop("num_gpus")
         if bundle_indices is not None:
@@ -127,7 +129,10 @@ def create_vllm_engines(
     assert vllm.__version__ >= "0.7.2", "OpenRLHF only supports vllm >= 0.7.2"
 
     vllm_engines = []
-    distributed_executor_backend = "uni" if tensor_parallel_size == 1 else "ray"
+    if ACCELERATOR_TYPE == "GPU":
+        distributed_executor_backend = "uni" if tensor_parallel_size == 1 else "ray"
+    elif ACCELERATOR_TYPE == "NPU":
+        distributed_executor_backend = "uni" if tensor_parallel_size == 1 else "mp"
     use_hybrid_engine = shared_pg is not None
     num_gpus = int(tensor_parallel_size == 1)
     if use_hybrid_engine and tensor_parallel_size == 1:
@@ -135,9 +140,9 @@ def create_vllm_engines(
         # 2 instances on the same GPUs.
         num_gpus = 0.2
 
-    if not use_hybrid_engine:
+    if not use_hybrid_engine and ACCELERATOR_TYPE == "GPU":
         # Create a big placement group to ensure that all engines are packed
-        bundles = [{"GPU": 1, "CPU": 1} for _ in range(num_engines * tensor_parallel_size)]
+        bundles = [{ACCELERATOR_TYPE: 1, "CPU": 1} for _ in range(num_engines * tensor_parallel_size)]
         shared_pg = placement_group(bundles, strategy="PACK")
         ray.get(shared_pg.ready())
 
@@ -160,8 +165,9 @@ def create_vllm_engines(
         vllm_engines.append(
             LLMRayActor.options(
                 num_cpus=num_gpus,
-                num_gpus=num_gpus,
-                scheduling_strategy=scheduling_strategy,
+                num_gpus=num_gpus if ACCELERATOR_TYPE == "GPU" else 0,
+                resources=None if ACCELERATOR_TYPE == "GPU" else {ACCELERATOR_TYPE: tensor_parallel_size},
+                scheduling_strategy=scheduling_strategy if ACCELERATOR_TYPE == "GPU" else None,
             ).remote(
                 model=pretrain,
                 enforce_eager=enforce_eager,
@@ -175,7 +181,7 @@ def create_vllm_engines(
                 trust_remote_code=True,
                 num_actors=num_actors,
                 gpu_memory_utilization=gpu_memory_utilization,
-                bundle_indices=bundle_indices,
+                bundle_indices=bundle_indices if ACCELERATOR_TYPE == "GPU" else None,
                 num_gpus=0.2 if use_hybrid_engine else 1,
                 enable_sleep_mode=vllm_enable_sleep,
                 noset_visible_devices=ray_noset_visible_devices(),
diff --git a/PyTorch/built-in/rl/OpenRLHF_v0.6.2_for_PyTorch/openrlhf/trainer/ray/vllm_worker_wrap.py b/PyTorch/built-in/rl/OpenRLHF_v0.6.2_for_PyTorch/openrlhf/trainer/ray/vllm_worker_wrap.py
index 6834855dcd029035f2471e487d44d90e7dc4d308..9f224c8b4f39728ec1dd3a47a2fbe222ee397abb 100644
--- a/PyTorch/built-in/rl/OpenRLHF_v0.6.2_for_PyTorch/openrlhf/trainer/ray/vllm_worker_wrap.py
+++ b/PyTorch/built-in/rl/OpenRLHF_v0.6.2_for_PyTorch/openrlhf/trainer/ray/vllm_worker_wrap.py
@@ -3,8 +3,12 @@ from vllm.worker.worker import Worker
 
 from openrlhf.utils.distributed_util import init_process_group
 from openrlhf.utils.logging_utils import init_logger
+from openrlhf import ACCELERATOR_TYPE
 
 from .utils import get_physical_gpu_id
 
+if ACCELERATOR_TYPE == "NPU":
+    from vllm_ascend.worker.worker import NPUWorker as Worker
+
 logger = init_logger(__name__)
 
diff --git a/PyTorch/built-in/rl/OpenRLHF_v0.6.2_for_PyTorch/requirements-npu.txt b/PyTorch/built-in/rl/OpenRLHF_v0.6.2_for_PyTorch/requirements-npu.txt
new file mode 100644
index 0000000000000000000000000000000000000000..f23991cbb0f5483ca2e1dd85a4f562c81353e05b
--- /dev/null
+++ b/PyTorch/built-in/rl/OpenRLHF_v0.6.2_for_PyTorch/requirements-npu.txt
@@ -0,0 +1,19 @@
+torch==2.5.1
+accelerate
+datasets
+deepspeed==0.16.3
+einops
+isort
+jsonlines
+loralib
+optimum
+packaging
+peft
+ray[default]==2.42.0
+tensorboard
+torchmetrics
+tqdm
+transformers==4.51.0
+transformers_stream_generator
+wandb
+wheel
diff --git a/PyTorch/built-in/rl/OpenRLHF_v0.6.2_for_PyTorch/setup.py b/PyTorch/built-in/rl/OpenRLHF_v0.6.2_for_PyTorch/setup.py
index 55038e6274ee5ed8a23976a064476486c7643fdf..84c2ae28639503c83531353ebc61e5c299640743 100644
--- a/PyTorch/built-in/rl/OpenRLHF_v0.6.2_for_PyTorch/setup.py
+++ b/PyTorch/built-in/rl/OpenRLHF_v0.6.2_for_PyTorch/setup.py
@@ -56,6 +56,14 @@ class bdist_wheel(_bdist_wheel):
         return python_version, abi_tag, platform_tag
 
 
+target_device = os.getenv("TARGET_DEVICE", "GPU").upper()
+
+if target_device == "NPU":
+    requirements_file = "requirements-npu.txt"
+else:
+    requirements_file = "requirements.txt"
+
+install_requires = _fetch_requirements(requirements_file)
 
 # Setup configuration
 setup(
@@ -72,7 +80,7 @@ setup(
     description="A Ray-based High-performance RLHF framework.",
     long_description=_fetch_readme(),
     long_description_content_type="text/markdown",
-    install_requires=_fetch_requirements("requirements.txt"),
+    install_requires=install_requires,
     extras_require={
         "vllm": ["vllm==0.7.2"],
         "vllm_latest": ["vllm>0.7.2"],