diff --git a/PyTorch/built-in/rl/VeRL_for_PyTorch/requirements-npu.txt b/PyTorch/built-in/rl/VeRL_for_PyTorch/requirements-npu.txt
new file mode 100644
index 0000000000000000000000000000000000000000..d6b8f48c41254a0d6fc948b63582a498f575d030
--- /dev/null
+++ b/PyTorch/built-in/rl/VeRL_for_PyTorch/requirements-npu.txt
@@ -0,0 +1,17 @@
+# requirements.txt records the full set of dependencies for development
+accelerate
+codetiming
+datasets
+dill
+hydra-core
+numpy
+pandas
+peft
+pyarrow>=15.0.0
+pybind11
+pylatexenc
+ray
+tensordict<0.6
+transformers>=4.51.0
+torchvision==0.20.1
+wandb
\ No newline at end of file
diff --git a/PyTorch/built-in/rl/VeRL_for_PyTorch/test/env_npu.sh b/PyTorch/built-in/rl/VeRL_for_PyTorch/test/env_npu.sh
new file mode 100644
index 0000000000000000000000000000000000000000..9a3e594385476ea32d23cbe10ab5756ff1114b98
--- /dev/null
+++ b/PyTorch/built-in/rl/VeRL_for_PyTorch/test/env_npu.sh
@@ -0,0 +1,45 @@
+#!/bin/bash
+CANN_INSTALL_PATH_CONF='/etc/Ascend/ascend_cann_install.info'
+
+if [ -f $CANN_INSTALL_PATH_CONF ]; then
+    CANN_INSTALL_PATH=$(cat $CANN_INSTALL_PATH_CONF | grep Install_Path | cut -d "=" -f 2)
+else
+    CANN_INSTALL_PATH="/usr/local/Ascend"
+fi
+
+if [ -d ${CANN_INSTALL_PATH}/ascend-toolkit/latest ]; then
+    source ${CANN_INSTALL_PATH}/ascend-toolkit/set_env.sh
+    source ${CANN_INSTALL_PATH}/nnal/atb/set_env.sh
+else
+    source ${CANN_INSTALL_PATH}/nnae/set_env.sh
+fi
+msnpureport -g error -d 0
+msnpureport -g error -d 1
+msnpureport -g error -d 2
+msnpureport -g error -d 3
+msnpureport -g error -d 4
+msnpureport -g error -d 5
+msnpureport -g error -d 6
+msnpureport -g error -d 7
+
+
+# Output host logs to the serial port, 0-off/1-on. Setting 0 disables printing logs to the screen, i.e. logs use the default output mode and are saved to the log file.
+export ASCEND_SLOG_PRINT_TO_STDOUT=0
+# Default log level, 0-debug/1-info/2-warning/3-error. Setting 3 outputs only error-level logs; adjust as needed.
+export ASCEND_GLOBAL_LOG_LEVEL=3
+# Whether Event logging is enabled for application logs, 0-off/1-on (default 1). Setting 0 disables Event logging.
+export ASCEND_GLOBAL_EVENT_ENABLE=0
+
+# Controls whether the task_queue operator dispatch queue is enabled and its optimization level.
+# - 0: disable the task_queue dispatch-queue optimization.
+# - 1 or unset: enable Level 1 optimization of the task_queue dispatch queue.
+# - 2: enable Level 2 optimization. See the official documentation for details on Level 1 and Level 2.
+export TASK_QUEUE_ENABLE=1
+
+# Whether to enable fftsplus, 0-off/1-on
+export ASCEND_ENHANCE_ENABLE=1
+# HCCL whitelist switch, 1-off/0-on. Setting 1 skips the HCCL communication whitelist check.
+export HCCL_WHITELIST_DISABLE=1
+export HCCL_IF_IP=$(hostname -I |awk '{print $1}')
+# In distributed training or inference, limits the timeout for socket connection setup between devices. Must be an integer; this value was chosen empirically.
+export HCCL_CONNECT_TIMEOUT=5400
diff --git a/PyTorch/built-in/rl/VeRL_for_PyTorch/verl/__init__.py b/PyTorch/built-in/rl/VeRL_for_PyTorch/verl/__init__.py
index 35d9d9162e160fc76962bcadf41cdcf5b7513090..a867d6a22d5b76267a90c3a6c483ab8c3a7e957a 100644
--- a/PyTorch/built-in/rl/VeRL_for_PyTorch/verl/__init__.py
+++ b/PyTorch/built-in/rl/VeRL_for_PyTorch/verl/__init__.py
@@ -14,7 +14,10 @@
 
 import os
 import math
-
+import pkg_resources
+from pkg_resources import DistributionNotFound
+from packaging.version import parse as parse_version
+from .utils.device import is_npu_available
 version_folder = os.path.dirname(os.path.join(os.path.abspath(__file__)))
 
 with open(os.path.join(version_folder, 'version/version')) as f:
@@ -24,6 +27,8 @@ from .protocol import DataProto
 
 from .utils.logging_utils import set_basic_config
 import logging
+if is_npu_available:
+    from .utils import npu_patch
 
 set_basic_config(level=logging.WARNING)
 
@@ -31,6 +36,19 @@ from . import single_controller
 
 __all__ = ['DataProto', "__version__"]
 
+package_name = 'transformers'
+required_version_spec = '4.51.0'
+try:
+    installed_version = pkg_resources.get_distribution(package_name).version
+    installed = parse_version(installed_version)
+    required = parse_version(required_version_spec)
+
+    if not installed >= required:
+        raise ValueError(f"{package_name} version required >= {required_version_spec}, current version is {installed}.")
+except DistributionNotFound as e:
+    raise ImportError(
+        f"{package_name} is not installed. Please run pip install {package_name}=={required_version_spec}") from e
+
 if os.getenv('VERL_USE_MODELSCOPE', 'False').lower() == 'true':
     import importlib
     if importlib.util.find_spec("modelscope") is None:
diff --git a/PyTorch/built-in/rl/VeRL_for_PyTorch/verl/models/transformers/qwen2_vl.py b/PyTorch/built-in/rl/VeRL_for_PyTorch/verl/models/transformers/qwen2_vl.py
index 718b9ca6f5b8a5569c790d47e586e0257d0f4ebc..97ee7f346688326ca08776cf9cd4d3da4881270c 100644
--- a/PyTorch/built-in/rl/VeRL_for_PyTorch/verl/models/transformers/qwen2_vl.py
+++ b/PyTorch/built-in/rl/VeRL_for_PyTorch/verl/models/transformers/qwen2_vl.py
@@ -22,7 +22,7 @@ from verl.utils.ulysses import gather_heads_scatter_seq, gather_seq_scatter_head
     get_ulysses_sequence_parallel_world_size, validate_ulysses_config
 
 try:
-    from flash_attn import flash_attn_func, flash_attn_varlen_func
+    from transformers.modeling_flash_attention_utils import flash_attn_func, flash_attn_varlen_func
 
     _flash_supports_window_size = "window_size" in list(inspect.signature(flash_attn_func).parameters)
 except ImportError:
diff --git a/PyTorch/built-in/rl/VeRL_for_PyTorch/verl/protocol.py b/PyTorch/built-in/rl/VeRL_for_PyTorch/verl/protocol.py
index 847bc92a7863dff74ddaf2322715c659628710be..14e907390404405942268b0226ac03cf88677c03 100644
--- a/PyTorch/built-in/rl/VeRL_for_PyTorch/verl/protocol.py
+++ b/PyTorch/built-in/rl/VeRL_for_PyTorch/verl/protocol.py
@@ -29,6 +29,7 @@ from tensordict import TensorDict
 from torch.utils.data import DataLoader, Dataset
 
 from verl.utils.py_functional import union_two_dict
+from verl.utils.device import get_torch_device
 
 __all__ = ['DataProto', 'union_tensor_dict']
 
@@ -766,7 +767,7 @@ def all_gather_data_proto(data: DataProto, process_group):
     group_size = torch.distributed.get_world_size(group=process_group)
     assert isinstance(data, DataProto)
     prev_device = data.batch.device
-    data.batch = data.batch.cuda(device=torch.cuda.current_device())
+    data.batch = data.batch.to(get_torch_device().current_device())
     data.batch = allgather_dict_tensors(data.batch.contiguous(), size=group_size, group=process_group, dim=0)
     data.batch = data.batch.to(prev_device)
     # all gather non_tensor_batch
diff --git a/PyTorch/built-in/rl/VeRL_for_PyTorch/verl/single_controller/base/worker.py b/PyTorch/built-in/rl/VeRL_for_PyTorch/verl/single_controller/base/worker.py
index dbe4cc600fa9d64a51ee420508bfb4b5b4db1351..fc42078e252749745ec1e25da55511d9d57965fb 100644
--- a/PyTorch/built-in/rl/VeRL_for_PyTorch/verl/single_controller/base/worker.py
+++ b/PyTorch/built-in/rl/VeRL_for_PyTorch/verl/single_controller/base/worker.py
@@ -17,6 +17,7 @@ the class for Worker
 import os
 import socket
 from dataclasses import dataclass
+from ...utils.device import get_device_name
 from .decorator import register, Dispatch, Execute
 
 
@@ -123,10 +124,10 @@ class Worker(WorkerHelper):
         # [SUPPORT AMD: torch]
         import torch
         ###
-
+        device_name = get_device_name()
         ###
         # [SUPPORT AMD: torch]
-        if "AMD" in torch.cuda.get_device_name():
+        if "AMD"
in device_name: os.environ['CUDA_VISIBLE_DEVICES'] = os.environ.get('ROCR_VISIBLE_DEVICES') os.environ['LOCAL_RANK'] = os.environ.get('RAY_LOCAL_RANK') ### @@ -144,13 +145,13 @@ class Worker(WorkerHelper): ### # [SUPPORT AMD: torch] - if "AMD" in torch.cuda.get_device_name(): + if "AMD" in device_name: self.local_rank = int(os.environ['LOCAL_RANK']) ### ### # [SUPPORT AMD: torch] - if "AMD" in torch.cuda.get_device_name(): + if "AMD" in device_name: cuda_visible_devices = str(local_rank) ### @@ -171,7 +172,7 @@ class Worker(WorkerHelper): ### # [SUPPORT AMD: torch] # torch.cuda.set_device(local_rank) - if "AMD" in torch.cuda.get_device_name(): + if "AMD" in device_name: torch.cuda.set_device(int(cuda_visible_devices)) ### diff --git a/PyTorch/built-in/rl/VeRL_for_PyTorch/verl/single_controller/ray/base.py b/PyTorch/built-in/rl/VeRL_for_PyTorch/verl/single_controller/ray/base.py index 19a530466daeed2c37419621f435aea267677e2c..381f8963df8fe32224ed9f04fff2d77228dd0d9c 100644 --- a/PyTorch/built-in/rl/VeRL_for_PyTorch/verl/single_controller/ray/base.py +++ b/PyTorch/built-in/rl/VeRL_for_PyTorch/verl/single_controller/ray/base.py @@ -81,16 +81,21 @@ class RayResourcePool(ResourcePool): self.pgs = None self.detached = detached - def get_placement_groups(self, strategy="STRICT_PACK", name=None): + def get_placement_groups(self, strategy="STRICT_PACK", device_name=None, name=None): if self.pgs is not None: return self.pgs pg_name_prefix = name if name else \ f"{self.name_prefix}verl_group_{'_'.join([str(count) for count in self._store])}:" # print(f"pg_name_prefix = {pg_name_prefix}") + # Non CUDA device compatibility + if device_name == "npu": + device_name = "NPU" + elif device_name == "cuda": + device_name = "GPU" pg_scheme = [[{ "CPU": self.max_collocate_count, - "GPU": 1 + device_name: 1 } if self.use_gpu else { "CPU": self.max_collocate_count } for _ in range(process_count)] for process_count in self._store] @@ -164,7 +169,8 @@ class RayClassWithInitArgs(ClassWithInitArgs): placement_group_bundle_idx, use_gpu: bool = True, num_gpus=1, - sharing_with=None) -> Any: + sharing_with=None, + device_name=None) -> Any: if sharing_with is not None: target_node_id = ray.get(sharing_with.get_node_id.remote()) cuda_visible_devices = ray.get(sharing_with.get_cuda_visible_devices.remote()) @@ -180,8 +186,10 @@ class RayClassWithInitArgs(ClassWithInitArgs): } options.update(self._options) - if use_gpu: + if use_gpu and device_name == "cuda": options["num_gpus"] = num_gpus + if use_gpu and device_name == "npu": + options["resources"] = {"NPU": num_gpus} if len(self._additional_resource) > 1: for k, v in self._additional_resource.items(): @@ -202,10 +210,12 @@ class RayWorkerGroup(WorkerGroup): name_prefix: str = None, detached=False, worker_names=None, + device_name=None, **kwargs) -> None: super().__init__(resource_pool=resource_pool, **kwargs) self.ray_cls_with_init = ray_cls_with_init self.name_prefix = get_random_string(length=6) if name_prefix is None else name_prefix + self.device_name = device_name if worker_names is not None: assert self._is_init_with_detached_workers @@ -237,7 +247,7 @@ class RayWorkerGroup(WorkerGroup): strategy = "PACK" if bin_pack: strategy = "STRICT_PACK" - pgs = resource_pool.get_placement_groups(strategy=strategy) + pgs = resource_pool.get_placement_groups(strategy=strategy, device_name=self.device_name) world_size = resource_pool.world_size self._world_size = world_size # cia.add_kwarg("_world_size", world_size) @@ -279,7 +289,8 @@ class RayWorkerGroup(WorkerGroup): worker = 
ray_cls_with_init(placement_group=pg, placement_group_bundle_idx=local_rank, use_gpu=use_gpu, - num_gpus=num_gpus) + num_gpus=num_gpus, + device_name=self.device_name) self._workers.append(worker) self._worker_names.append(name) diff --git a/PyTorch/built-in/rl/VeRL_for_PyTorch/verl/trainer/config/ppo_trainer.yaml b/PyTorch/built-in/rl/VeRL_for_PyTorch/verl/trainer/config/ppo_trainer.yaml index 5dfdb74433510c0ddc42b1c32b12606239949544..59e9b9db570856ec34a2692c913ee845a14c20d3 100644 --- a/PyTorch/built-in/rl/VeRL_for_PyTorch/verl/trainer/config/ppo_trainer.yaml +++ b/PyTorch/built-in/rl/VeRL_for_PyTorch/verl/trainer/config/ppo_trainer.yaml @@ -69,6 +69,7 @@ actor_rollout_ref: wrap_policy: # transformer_layer_cls_to_wrap: None min_num_params: 0 + use_torch_compile: ${actor_rollout_ref.actor.use_torch_compile} log_prob_micro_batch_size: null # will be deprecated, use log_prob_micro_batch_size_per_gpu log_prob_micro_batch_size_per_gpu: null log_prob_use_dynamic_bsz: ${actor_rollout_ref.actor.use_dynamic_bsz} diff --git a/PyTorch/built-in/rl/VeRL_for_PyTorch/verl/trainer/fsdp_sft_trainer.py b/PyTorch/built-in/rl/VeRL_for_PyTorch/verl/trainer/fsdp_sft_trainer.py index 2efdd69eae8f9fa349a7616f7f8aa7b5598857f9..b7a833a2219b212c1764dc7d10b5648cf11e8ac9 100644 --- a/PyTorch/built-in/rl/VeRL_for_PyTorch/verl/trainer/fsdp_sft_trainer.py +++ b/PyTorch/built-in/rl/VeRL_for_PyTorch/verl/trainer/fsdp_sft_trainer.py @@ -20,9 +20,6 @@ TODO(zhangchi.usc1992) import os -os.environ['NCCL_DEBUG'] = 'WARN' -os.environ['TOKENIZERS_PARALLELISM'] = 'true' - import logging import re from contextlib import nullcontext @@ -35,8 +32,8 @@ from transformers import AutoTokenizer, AutoModelForCausalLM, PreTrainedModel, A from verl.utils.torch_functional import get_cosine_schedule_with_warmup from tensordict import TensorDict from torch.utils.data import DataLoader, DistributedSampler -from flash_attn.bert_padding import pad_input, unpad_input, rearrange, index_first_axis +from verl.utils.device import get_device_name, is_cuda_available, get_torch_device, is_npu_available from verl.utils.fsdp_utils import get_fsdp_wrap_policy, init_fn, get_init_weight_context_manager from verl.utils.dataset import SFTDataset from verl.utils.dataset.multiturn_sft_dataset import MultiTurnSFTDataset @@ -53,6 +50,15 @@ from verl.workers.sharding_manager import FSDPUlyssesShardingManager from verl.utils.ulysses import ulysses_pad_and_slice_inputs, gather_outpus_and_unpad from verl import DataProto +if is_cuda_available: + from flash_attn.bert_padding import pad_input, unpad_input, rearrange, index_first_axis +elif is_npu_available: + from transformers.integrations.npu_flash_attention import pad_input, unpad_input, rearrange, \ + index_first_axis + +os.environ['NCCL_DEBUG'] = 'WARN' +os.environ['TOKENIZERS_PARALLELISM'] = 'true' + logger = logging.getLogger(__file__) logger.setLevel(os.getenv('VERL_SFT_LOGGING_LEVEL', 'WARN')) @@ -107,6 +113,7 @@ class FSDPSFTTrainer(object): # TODO: add checkpoint manager if self.device_mesh.get_rank() == 0: print(self.config) + self.device_name = get_device_name() def _normalize_config_bsz(self): dp_size = self.device_mesh.size(0) if not self.ulysses_device_mesh else self.ulysses_device_mesh.size(0) @@ -260,7 +267,7 @@ class FSDPSFTTrainer(object): mixed_precision=mixed_precision, device_mesh=self.device_mesh, sync_module_states=True, - device_id=torch.cuda.current_device(), + device_id=get_torch_device().current_device(), cpu_offload=cpu_offload, use_orig_params=False) @@ -292,16 +299,16 @@ class 
FSDPSFTTrainer(object): use_sp = self.use_remove_padding and self.config.ulysses_sequence_parallel_size > 1 # Move inputs to GPU and prepare loss mask - input_ids = batch['input_ids'].cuda() - attention_mask = batch['attention_mask'].cuda() - position_ids = batch['position_ids'].cuda() - loss_mask = batch.pop('loss_mask')[:, :-1].reshape(-1).cuda() + input_ids = batch['input_ids'].to(self.device_name) + attention_mask = batch['attention_mask'].to(self.device_name) + position_ids = batch['position_ids'].to(self.device_name) + loss_mask = batch.pop('loss_mask')[:, :-1].reshape(-1).to(self.device_name) loss_fct = nn.CrossEntropyLoss(reduction='none') # Context manager for sequence parallel if needed context = self.sharding_manager if use_sp else nullcontext() with context: - with torch.autocast(device_type='cuda', dtype=torch.bfloat16): + with torch.autocast(device_type=self.device_name, dtype=torch.bfloat16): if not use_sp: # Standard forward pass without sequence parallel labels = input_ids[:, 1:].contiguous() @@ -420,15 +427,23 @@ class FSDPSFTTrainer(object): log_gpu_memory_usage('After offload weights', logger=logger) - step_loss = torch.tensor(step_loss).cuda() - torch.distributed.all_reduce(step_loss, op=torch.distributed.ReduceOp.AVG) + step_loss = torch.tensor(step_loss).to(self.device_name) + if is_cuda_available: + torch.distributed.all_reduce(step_loss, op=torch.distributed.ReduceOp.AVG) + elif is_npu_available: + torch.distributed.all_reduce(step_loss) + step_loss /= self.ulysses_device_mesh.size(0) return {'train/loss': step_loss.detach().item(), 'train/lr(1e-3)': lr * 1e3} def validation_step(self, batch: TensorDict): self.fsdp_model.eval() with torch.no_grad(): loss = self._compute_loss_and_backward(batch, do_backward=False) - torch.distributed.all_reduce(loss, op=torch.distributed.ReduceOp.AVG) + if is_cuda_available: + torch.distributed.all_reduce(loss, op=torch.distributed.ReduceOp.AVG) + elif is_npu_available: + torch.distributed.all_reduce(loss) + loss /= self.ulysses_device_mesh.size(0) return loss def save_checkpoint(self, step): @@ -477,7 +492,7 @@ class FSDPSFTTrainer(object): total=self.steps_per_epoch, desc=f"Epoch {epoch+1}/{self.config.trainer.total_epochs}"): global_step += 1 - data = TensorDict(data, batch_size=self.config.data.train_batch_size).cuda() + data = TensorDict(data, batch_size=self.config.data.train_batch_size).to(self.device_name) metric = self.training_step(data) if rank == 0: tracking.log(data=metric, step=global_step) @@ -487,7 +502,8 @@ class FSDPSFTTrainer(object): # Perform final validation val_losses = [] for val_data in self.val_dataloader: - val_data = TensorDict(val_data, batch_size=self.config.data.micro_batch_size_per_gpu).cuda() + val_data = TensorDict(val_data, + batch_size=self.config.data.micro_batch_size_per_gpu).to(self.device_name) val_loss = self.validation_step(val_data) val_losses.append(val_loss) if rank == 0: @@ -503,7 +519,7 @@ class FSDPSFTTrainer(object): # validation val_losses = [] for data in self.val_dataloader: - data = TensorDict(data, batch_size=self.config.data.micro_batch_size_per_gpu).cuda() + data = TensorDict(data, batch_size=self.config.data.micro_batch_size_per_gpu).to(self.device_name) val_loss = self.validation_step(data) val_losses.append(val_loss) if rank == 0: @@ -526,11 +542,12 @@ from verl.utils.distributed import initialize_global_process_group @hydra.main(config_path='config', config_name='sft_trainer', version_base=None) def main(config): + device_name = get_device_name() local_rank, rank, 
world_size = initialize_global_process_group() - device_mesh = init_device_mesh(device_type='cuda', mesh_shape=(world_size,), mesh_dim_names=('fsdp',)) + device_mesh = init_device_mesh(device_type=device_name, mesh_shape=(world_size,), mesh_dim_names=('fsdp',)) dp_size = world_size // config.ulysses_sequence_parallel_size - ulysses_device_mesh = init_device_mesh(device_type='cuda', + ulysses_device_mesh = init_device_mesh(device_type=device_name, mesh_shape=(dp_size, config.ulysses_sequence_parallel_size), mesh_dim_names=('dp', 'sp')) trainer = FSDPSFTTrainer(config=config, device_mesh=device_mesh, ulysses_device_mesh=ulysses_device_mesh) diff --git a/PyTorch/built-in/rl/VeRL_for_PyTorch/verl/trainer/main_ppo.py b/PyTorch/built-in/rl/VeRL_for_PyTorch/verl/trainer/main_ppo.py index 773f230aabb15e91aac9048fb36dca0237684e22..7e6d5a93c7622e6e26cb0cd13d5910bf0d28bfe5 100644 --- a/PyTorch/built-in/rl/VeRL_for_PyTorch/verl/trainer/main_ppo.py +++ b/PyTorch/built-in/rl/VeRL_for_PyTorch/verl/trainer/main_ppo.py @@ -19,6 +19,7 @@ from verl.trainer.ppo.ray_trainer import RayPPOTrainer import os import ray import hydra +from verl.utils.device import is_npu_available def get_custom_reward_fn(config): @@ -181,7 +182,8 @@ class TaskRunner: compute_score=compute_score, reward_fn_key=config.data.reward_fn_key) resource_pool_manager = ResourcePoolManager(resource_pool_spec=resource_pool_spec, mapping=mapping) - + # Non CUDA device compatibility + device_name = "npu" if is_npu_available else "cuda" trainer = RayPPOTrainer(config=config, tokenizer=tokenizer, processor=processor, @@ -189,7 +191,8 @@ class TaskRunner: resource_pool_manager=resource_pool_manager, ray_worker_group_cls=ray_worker_group_cls, reward_fn=reward_fn, - val_reward_fn=val_reward_fn) + val_reward_fn=val_reward_fn, + device_name=device_name) trainer.init_workers() trainer.fit() diff --git a/PyTorch/built-in/rl/VeRL_for_PyTorch/verl/trainer/ppo/ray_trainer.py b/PyTorch/built-in/rl/VeRL_for_PyTorch/verl/trainer/ppo/ray_trainer.py index 047d6e36241bacbafdc7615b5be90ef07f86cf6c..0f28f3131535c973fa39f6af49ed9425b97018ce 100644 --- a/PyTorch/built-in/rl/VeRL_for_PyTorch/verl/trainer/ppo/ray_trainer.py +++ b/PyTorch/built-in/rl/VeRL_for_PyTorch/verl/trainer/ppo/ray_trainer.py @@ -108,7 +108,10 @@ class ResourcePoolManager: def _check_resource_available(self): """Check if the resource pool can be satisfied in this ray cluster.""" node_available_resources = ray.state.available_resources_per_node() - node_available_gpus = {node: node_info.get('GPU', 0) for node, node_info in node_available_resources.items()} + node_available_gpus = { + node: node_info.get('NPU', 0) if 'NPU' in node_info else node_info.get('GPU', 0) + for node, node_info in node_available_resources.items() + } # check total required gpus can be satisfied total_available_gpus = sum(node_available_gpus.values()) @@ -252,7 +255,8 @@ class RayPPOTrainer(object): ray_worker_group_cls: RayWorkerGroup = RayWorkerGroup, processor=None, reward_fn=None, - val_reward_fn=None): + val_reward_fn=None, + device_name=None): # assert torch.cuda.is_available(), 'cuda must be available on driver' @@ -273,6 +277,7 @@ class RayPPOTrainer(object): self.use_reference_policy = Role.RefPolicy in role_worker_mapping self.use_rm = Role.RewardModel in role_worker_mapping self.ray_worker_group_cls = ray_worker_group_cls + self.device_name = device_name self.validation_generations_logger = ValidationGenerationsLogger() # define in-reward KL control @@ -638,7 +643,9 @@ class RayPPOTrainer(object): if self.use_rm: # 
we create a RM here resource_pool = self.resource_pool_manager.get_resource_pool(Role.RewardModel) - rm_cls = RayClassWithInitArgs(self.role_worker_mapping[Role.RewardModel], config=self.config.reward_model) + rm_cls = RayClassWithInitArgs(self.role_worker_mapping[Role.RewardModel], + config=self.config.reward_model, + device_name=self.device_name) self.resource_pool_to_cls[resource_pool]['rm'] = rm_cls # initialize WorkerGroup @@ -649,7 +656,9 @@ class RayPPOTrainer(object): self.wg_dicts = [] for resource_pool, class_dict in self.resource_pool_to_cls.items(): worker_dict_cls = create_colocated_worker_cls(class_dict=class_dict) - wg_dict = self.ray_worker_group_cls(resource_pool=resource_pool, ray_cls_with_init=worker_dict_cls) + wg_dict = self.ray_worker_group_cls(resource_pool=resource_pool, + ray_cls_with_init=worker_dict_cls, + device_name=self.device_name) spawn_wg = wg_dict.spawn(prefix_set=class_dict.keys()) all_wg.update(spawn_wg) # keep the referece of WorkerDict to support ray >= 2.31. Ref: https://github.com/ray-project/ray/pull/45699 diff --git a/PyTorch/built-in/rl/VeRL_for_PyTorch/verl/utils/checkpoint/fsdp_checkpoint_manager.py b/PyTorch/built-in/rl/VeRL_for_PyTorch/verl/utils/checkpoint/fsdp_checkpoint_manager.py index c59f844dfeef365d3ebf7ec2829cf4a6f4cde29c..0cbba60433b5ac1619418316076ecd1882e07cc9 100644 --- a/PyTorch/built-in/rl/VeRL_for_PyTorch/verl/utils/checkpoint/fsdp_checkpoint_manager.py +++ b/PyTorch/built-in/rl/VeRL_for_PyTorch/verl/utils/checkpoint/fsdp_checkpoint_manager.py @@ -23,6 +23,7 @@ from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, StateDictTy from torch.distributed.fsdp import ShardedStateDictConfig, ShardedOptimStateDictConfig from verl.utils.fs import copy_to_local, is_non_local +from verl.utils.device import is_cuda_available from transformers import PreTrainedTokenizer, ProcessorMixin @@ -96,8 +97,8 @@ class FSDPCheckpointManager(BaseCheckpointManager): lr_scheduler_state_dict = extra_state_dict['lr_scheduler'] - state_dict_cfg = ShardedStateDictConfig(offload_to_cpu=True) - optim_cfg = ShardedOptimStateDictConfig(offload_to_cpu=True) + state_dict_cfg = ShardedStateDictConfig(offload_to_cpu=True if is_cuda_available else False) + optim_cfg = ShardedOptimStateDictConfig(offload_to_cpu=True if is_cuda_available else False) with FSDP.state_dict_type(self.model, StateDictType.SHARDED_STATE_DICT, state_dict_cfg, optim_cfg): self.model.load_state_dict(model_state_dict) if self.optimizer is not None: @@ -128,8 +129,8 @@ class FSDPCheckpointManager(BaseCheckpointManager): torch.distributed.barrier() # every rank will save its own model and optim shard - state_dict_cfg = ShardedStateDictConfig(offload_to_cpu=True) - optim_cfg = ShardedOptimStateDictConfig(offload_to_cpu=True) + state_dict_cfg = ShardedStateDictConfig(offload_to_cpu=True if is_cuda_available else False) + optim_cfg = ShardedOptimStateDictConfig(offload_to_cpu=True if is_cuda_available else False) with warnings.catch_warnings(): warnings.simplefilter("ignore") with FSDP.state_dict_type(self.model, StateDictType.SHARDED_STATE_DICT, state_dict_cfg, optim_cfg): diff --git a/PyTorch/built-in/rl/VeRL_for_PyTorch/verl/utils/debug/performance.py b/PyTorch/built-in/rl/VeRL_for_PyTorch/verl/utils/debug/performance.py index 615475a66a5e45853540df2efd09c25991b43e12..6b6bbd4ef8971fd4a75d877075b3690f776eb84b 100644 --- a/PyTorch/built-in/rl/VeRL_for_PyTorch/verl/utils/debug/performance.py +++ b/PyTorch/built-in/rl/VeRL_for_PyTorch/verl/utils/debug/performance.py @@ -15,12 +15,13 @@ 
import torch import torch.distributed as dist import logging +from verl.utils.device import get_torch_device def log_gpu_memory_usage(head: str, logger: logging.Logger = None, level=logging.DEBUG, rank: int = 0): if (not dist.is_initialized()) or (rank is None) or (dist.get_rank() == rank): - memory_allocated = torch.cuda.memory_allocated() / 1024**3 - memory_reserved = torch.cuda.memory_reserved() / 1024**3 + memory_allocated = get_torch_device().memory_allocated() / 1024**3 + memory_reserved = get_torch_device().memory_reserved() / 1024**3 message = f'{head}, memory allocated (GB): {memory_allocated}, memory reserved (GB): {memory_reserved}' diff --git a/PyTorch/built-in/rl/VeRL_for_PyTorch/verl/utils/device.py b/PyTorch/built-in/rl/VeRL_for_PyTorch/verl/utils/device.py new file mode 100644 index 0000000000000000000000000000000000000000..22b7cbee2ba3532d381278a2eb6e8bb343fd19a3 --- /dev/null +++ b/PyTorch/built-in/rl/VeRL_for_PyTorch/verl/utils/device.py @@ -0,0 +1,82 @@ +# Copyright (c) OpenMMLab. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +from typing import Optional + +import torch + +logger = logging.getLogger(__name__) + + +def is_torch_npu_available() -> bool: + """Check the availability of NPU""" + try: + import torch_npu # noqa: F401 + + return torch.npu.is_available() + except ImportError: + return False + + +is_cuda_available = torch.cuda.is_available() +is_npu_available = is_torch_npu_available() + + +def get_device_name() -> str: + """Function that gets the torch.device based on the current machine. + This currently only supports CPU, CUDA, NPU. + Returns: + device + """ + if is_cuda_available: + device = "cuda" + elif is_npu_available: + device = "npu" + else: + device = "cpu" + return device + + +def get_device(device_name: Optional[str] = None) -> torch.device: + """Function that takes an optional device string, verifies it's correct and available given the machine and + distributed settings, and returns a :func:`~torch.device`. If device string is not provided, this function will + infer the device based on the environment. + If CUDA-like is available and being used, this function also sets the CUDA-like device. + Args: + device (Optional[str]): The name of the device to use, e.g. "cuda" or "cpu" or "npu". + Example: + >>> device = get_device("cuda") + >>> device + device(type='cuda', index=0) + Returns: + torch.device: Device + """ + if device_name is None: + device_name = get_device_name() + device = torch.device(device_name) + return device + + +def get_torch_device() -> any: + """Return the corresponding torch attribute based on the device type string. + Returns: + module: The corresponding torch device namespace, or torch.cuda if not found. 
+ """ + device_name = get_device_name() + try: + return getattr(torch, device_name) + except AttributeError: + logger.warning(f"Device namespace '{device_name}' not found in torch, try to load torch.cuda.") + return torch.cuda \ No newline at end of file diff --git a/PyTorch/built-in/rl/VeRL_for_PyTorch/verl/utils/distributed.py b/PyTorch/built-in/rl/VeRL_for_PyTorch/verl/utils/distributed.py index 6fea5a29cd943ef91c8f27f44db2a69e40702cf7..af859793405306bf0dfa7bcb73683d382baf073f 100644 --- a/PyTorch/built-in/rl/VeRL_for_PyTorch/verl/utils/distributed.py +++ b/PyTorch/built-in/rl/VeRL_for_PyTorch/verl/utils/distributed.py @@ -13,16 +13,18 @@ # limitations under the License. """Utilities for distributed training.""" import os +from verl.utils.device import is_cuda_available, get_torch_device def initialize_global_process_group(timeout_second=36000): import torch.distributed from datetime import timedelta - torch.distributed.init_process_group('nccl', timeout=timedelta(seconds=timeout_second)) + torch.distributed.init_process_group('nccl' if is_cuda_available else 'hccl', + timeout=timedelta(seconds=timeout_second)) local_rank = int(os.environ["LOCAL_RANK"]) rank = int(os.environ["RANK"]) world_size = int(os.environ["WORLD_SIZE"]) if torch.distributed.is_initialized(): - torch.cuda.set_device(local_rank) + get_torch_device().set_device(local_rank) return local_rank, rank, world_size diff --git a/PyTorch/built-in/rl/VeRL_for_PyTorch/verl/utils/flops_counter.py b/PyTorch/built-in/rl/VeRL_for_PyTorch/verl/utils/flops_counter.py index 9bcebc85189ad4a20b8d99bfc18fa66dc1d0970c..795fb18366f20d70a16c7398bae7f57b4cb28ac4 100644 --- a/PyTorch/built-in/rl/VeRL_for_PyTorch/verl/utils/flops_counter.py +++ b/PyTorch/built-in/rl/VeRL_for_PyTorch/verl/utils/flops_counter.py @@ -14,6 +14,7 @@ import torch from transformers import PretrainedConfig +from verl.utils.device import get_torch_device VALID_CONFIG_TYPE = {"llama", "qwen2", "qwen2_vl", "qwen2_5_vl", "deepseek_v3"} @@ -30,7 +31,7 @@ def get_device_flops(unit="T"): ptr += 1 return number - device_name = torch.cuda.get_device_name() + device_name = get_torch_device().get_device_name() flops = float("inf") # INF flops for unkown gpu type if "MI300X" in device_name: diff --git a/PyTorch/built-in/rl/VeRL_for_PyTorch/verl/utils/fsdp_utils.py b/PyTorch/built-in/rl/VeRL_for_PyTorch/verl/utils/fsdp_utils.py index b3f5b73534ea5eb9e9f961a0280893e93ae2d62b..0fa684e1a95f06b27238ecdf14a8f92201f274b4 100644 --- a/PyTorch/built-in/rl/VeRL_for_PyTorch/verl/utils/fsdp_utils.py +++ b/PyTorch/built-in/rl/VeRL_for_PyTorch/verl/utils/fsdp_utils.py @@ -27,12 +27,13 @@ from transformers.trainer_pt_utils import get_module_class_from_name import torch import torch.nn as nn import torch.distributed as dist +from verl.utils.device import get_torch_device def init_fn(x: torch.nn.Module): if not torch.distributed.get_rank() == 0: - x = x.to_empty(device=torch.cuda.current_device(), recurse=False) - torch.cuda.empty_cache() + x = x.to_empty(device=get_torch_device().current_device(), recurse=False) + get_torch_device().empty_cache() return x @@ -129,7 +130,7 @@ def offload_fsdp_model_to_cpu(model: FSDP, empty_cache: bool = True): flat_param._local_shard = flat_param.data assert id(flat_param._local_shard) != id(flat_param.data) if empty_cache: - torch.cuda.empty_cache() + get_torch_device().empty_cache() @torch.no_grad() @@ -138,7 +139,7 @@ def load_fsdp_model_to_gpu(model: FSDP): # lazy init FSDP model _lazy_init(model, model) assert model._is_root, f"Only support root model 
loading to GPU" - device_id = torch.cuda.current_device() + device_id = get_torch_device().current_device() for handle in model._all_handles: if handle._offload_params: continue @@ -245,7 +246,7 @@ def parallel_load_safetensors(filepath): ckpt_chunks = [ckpt_chunks[rank * size:rank * size + size] for rank in range(world_size)] shard_states = {} - device = torch.cuda.current_device() + device = get_torch_device().current_device() for rank, files in enumerate(ckpt_chunks): if rank == dist.get_rank(): for file in files: @@ -284,7 +285,7 @@ def parallel_init_module_fn(module: torch.nn.Module, shard_states: Dict[str, tor @torch.no_grad() def create_and_sync_state(param_name, state, is_param): assert param_name in shard_states, f"{param_name} not loaded" - device = torch.cuda.current_device() + device = get_torch_device().current_device() if is_param: param = torch.nn.Parameter(torch.empty_like(state.data, device=device), requires_grad=state.requires_grad) else: # buffer diff --git a/PyTorch/built-in/rl/VeRL_for_PyTorch/verl/utils/npu_patch.py b/PyTorch/built-in/rl/VeRL_for_PyTorch/verl/utils/npu_patch.py new file mode 100644 index 0000000000000000000000000000000000000000..302e47041901c7f0d453b8746b585ee748ad155d --- /dev/null +++ b/PyTorch/built-in/rl/VeRL_for_PyTorch/verl/utils/npu_patch.py @@ -0,0 +1,78 @@ +# Copyright 2025 The Qwen Team and The HuggingFace Inc. team +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Optional, Tuple + +import torch +import torch.nn.functional as F +from torch_npu import npu_rotary_mul as apply_rotary_emb +from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import apply_rotary_pos_emb_vision, \ + Qwen2_5_VLVisionSdpaAttention +from transformers.models.qwen2_5_vl import modeling_qwen2_5_vl +from transformers.utils import logging + + +logger = logging.get_logger(__name__) + + +def apply_rotary_pos_emb_flashatt_npu( + q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor +) -> Tuple[torch.Tensor, torch.Tensor]: + q_embed = apply_rotary_emb(q.float(), cos.unsqueeze(0).unsqueeze(2).float(), sin.unsqueeze(0).unsqueeze(2).float()).type_as(q) + k_embed = apply_rotary_emb(k.float(), cos.unsqueeze(0).unsqueeze(2).float(), sin.unsqueeze(0).unsqueeze(2).float()).type_as(k) + return q_embed, k_embed + + +def sdpa_forward( + self, + hidden_states: torch.Tensor, + cu_seqlens: torch.Tensor, + rotary_pos_emb: Optional[torch.Tensor] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, +) -> torch.Tensor: + seq_length = hidden_states.shape[0] + q, k, v = self.qkv(hidden_states).reshape(seq_length, 3, self.num_heads, -1).permute(1, 0, 2, 3).unbind(0) + if position_embeddings is None: + logger.warning_once( + "The attention layers in this model are transitioning from computing the RoPE embeddings internally " + "through `rotary_pos_emb` (2D tensor of RoPE theta values), to using externally computed " + "`position_embeddings` (Tuple of tensors, containing cos and sin). 
In v4.54 `rotary_pos_emb` will be " + "removed and `position_embeddings` will be mandatory." + ) + emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1) + cos = emb.cos() + sin = emb.sin() + else: + cos, sin = position_embeddings + q, k = apply_rotary_pos_emb_vision(q, k, cos, sin) + + attention_mask = torch.zeros([1, seq_length, seq_length], device=q.device, dtype=torch.bool) + for i in range(1, len(cu_seqlens)): + attention_mask[..., cu_seqlens[i - 1] : cu_seqlens[i], cu_seqlens[i - 1] : cu_seqlens[i]] = True + q = q.transpose(0, 1) + k = k.transpose(0, 1) + v = v.transpose(0, 1) + # scaled_dot_product_attention on ASCEND NPU only supports the case where q/k/v/attention_mask is both + # 4-dimensional, so attention_mask can be changed to 4-dimensional through unsqueeze + attn_output = F.scaled_dot_product_attention( + q.unsqueeze(0), k.unsqueeze(0), v.unsqueeze(0), attention_mask.unsqueeze(0), dropout_p=0.0 + ) + attn_output = attn_output.squeeze(0).transpose(0, 1) + attn_output = attn_output.reshape(seq_length, -1) + attn_output = self.proj(attn_output) + return attn_output + + +Qwen2_5_VLVisionSdpaAttention.forward = sdpa_forward +modeling_qwen2_5_vl.apply_rotary_pos_emb_flashatt = apply_rotary_pos_emb_flashatt_npu diff --git a/PyTorch/built-in/rl/VeRL_for_PyTorch/verl/workers/actor/dp_actor.py b/PyTorch/built-in/rl/VeRL_for_PyTorch/verl/workers/actor/dp_actor.py index a3ae5e6f00538216ac74a2139c89027d482b836a..908ba233d5383fd8313d00df4d3a88210ead339c 100644 --- a/PyTorch/built-in/rl/VeRL_for_PyTorch/verl/workers/actor/dp_actor.py +++ b/PyTorch/built-in/rl/VeRL_for_PyTorch/verl/workers/actor/dp_actor.py @@ -31,7 +31,13 @@ from verl.utils.ulysses import ulysses_pad_and_slice_inputs, gather_outpus_and_u from verl.utils.seqlen_balancing import rearrange_micro_batches, get_reverse_idx import verl.utils.torch_functional as verl_F -from flash_attn.bert_padding import pad_input, unpad_input, rearrange, index_first_axis +from verl.utils.device import get_device_name, get_torch_device, is_cuda_available, is_npu_available + +if is_cuda_available: + from flash_attn.bert_padding import pad_input, unpad_input, rearrange, index_first_axis +elif is_npu_available: + from transformers.integrations.npu_flash_attention import pad_input, unpad_input, rearrange, \ + index_first_axis __all__ = ['DataParallelPPOActor'] @@ -57,6 +63,7 @@ class DataParallelPPOActor(BasePPOActor): torch.compile(verl_F.entropy_from_logits, dynamic=True) if self.config.get('use_torch_compile', True) # use torch compile by default else verl_F.entropy_from_logits) + self.device_name = get_device_name() def _forward_micro_batch(self, micro_batch, temperature) -> Tuple[torch.Tensor, torch.Tensor]: """ @@ -71,7 +78,7 @@ class DataParallelPPOActor(BasePPOActor): multi_modal_inputs[key] = torch.cat([inputs[key] for inputs in micro_batch['multi_modal_inputs']], dim=0) - with torch.autocast(device_type='cuda', dtype=torch.bfloat16): + with torch.autocast(device_type=self.device_name, dtype=torch.bfloat16): input_ids = micro_batch['input_ids'] batch_size, seqlen = input_ids.shape attention_mask = micro_batch['attention_mask'] @@ -275,9 +282,11 @@ class DataParallelPPOActor(BasePPOActor): for data in micro_batches: # Support all hardwares if isinstance(data, DataProto): - data = {**data.batch.to(torch.cuda.current_device()), **data.non_tensor_batch} + data = { + **data.batch.to(get_torch_device().current_device()), **data.non_tensor_batch + } else: - data = data.to(torch.cuda.current_device()) # actor device is cpu when using offload + 
data = data.to(get_torch_device().current_device()) # actor device is cpu when using offload responses = data['responses'] response_length = responses.size(1) attention_mask = data['attention_mask'] diff --git a/PyTorch/built-in/rl/VeRL_for_PyTorch/verl/workers/critic/dp_critic.py b/PyTorch/built-in/rl/VeRL_for_PyTorch/verl/workers/critic/dp_critic.py index d100425b3cf25f46f822a8687f77e52b1bfc8956..95cef04c98152865afd2c645e16812665708fdcc 100644 --- a/PyTorch/built-in/rl/VeRL_for_PyTorch/verl/workers/critic/dp_critic.py +++ b/PyTorch/built-in/rl/VeRL_for_PyTorch/verl/workers/critic/dp_critic.py @@ -31,7 +31,13 @@ from verl.utils.torch_functional import masked_mean from verl.utils.ulysses import ulysses_pad_and_slice_inputs, gather_outpus_and_unpad from verl.utils.seqlen_balancing import rearrange_micro_batches, get_reverse_idx -from flash_attn.bert_padding import pad_input, unpad_input, rearrange, index_first_axis +from verl.utils.device import get_device_name, get_torch_device, is_npu_available, is_cuda_available + +if is_cuda_available: + from flash_attn.bert_padding import pad_input, unpad_input, rearrange, index_first_axis +elif is_npu_available: + from transformers.integrations.npu_flash_attention import pad_input, unpad_input, rearrange, \ + index_first_axis __all__ = ['DataParallelPPOCritic'] @@ -46,6 +52,7 @@ class DataParallelPPOCritic(BasePPOCritic): print(f'Critic use_remove_padding={self.use_remove_padding}') self.ulysses_sequence_parallel_size = self.config.get('ulysses_sequence_parallel_size', 1) + self.device_name = get_device_name() def _forward_micro_batch(self, micro_batch): response_length = micro_batch['responses'].size(-1) @@ -55,7 +62,7 @@ class DataParallelPPOCritic(BasePPOCritic): multi_modal_inputs[key] = torch.cat([inputs[key] for inputs in micro_batch['multi_modal_inputs']], dim=0) - with torch.autocast(device_type='cuda', dtype=torch.bfloat16): + with torch.autocast(device_type=self.device_name, dtype=torch.bfloat16): input_ids = micro_batch['input_ids'] batch, seqlen = input_ids.shape attention_mask = micro_batch['attention_mask'] @@ -206,9 +213,11 @@ class DataParallelPPOCritic(BasePPOCritic): for data in micro_batches: #Support all devices if isinstance(data, DataProto): - data = {**data.batch.to(torch.cuda.current_device()), **data.non_tensor_batch} + data = { + **data.batch.to(get_torch_device().current_device()), **data.non_tensor_batch + } else: - data = data.to(torch.cuda.current_device()) # critic device is cpu when using offload + data = data.to(get_torch_device().current_device()) # critic device is cpu when using offload input_ids = data['input_ids'] responses = data['responses'] attention_mask = data['attention_mask'] diff --git a/PyTorch/built-in/rl/VeRL_for_PyTorch/verl/workers/fsdp_workers.py b/PyTorch/built-in/rl/VeRL_for_PyTorch/verl/workers/fsdp_workers.py index ca5578c08cd62f08eff3a940a530cd0d702ab3ff..34a4b893d17b61b12daf4207b17d6f98684fab63 100644 --- a/PyTorch/built-in/rl/VeRL_for_PyTorch/verl/workers/fsdp_workers.py +++ b/PyTorch/built-in/rl/VeRL_for_PyTorch/verl/workers/fsdp_workers.py @@ -39,18 +39,21 @@ from verl.utils.model import compute_position_id_with_mask from verl.utils.flops_counter import FlopsCounter from verl.utils.checkpoint.fsdp_checkpoint_manager import FSDPCheckpointManager from verl.workers.sharding_manager.fsdp_ulysses import FSDPUlyssesShardingManager +from verl.utils.device import get_device_name, is_cuda_available, get_torch_device, is_npu_available from codetiming import Timer logger = logging.getLogger(__file__) 
logger.setLevel(os.getenv('VERL_PPO_LOGGING_LEVEL', 'WARN')) +device_name = get_device_name() + def create_device_mesh(world_size, fsdp_size): if fsdp_size < 0 or fsdp_size >= world_size: - device_mesh = init_device_mesh('cuda', mesh_shape=(world_size,), mesh_dim_names=['fsdp']) + device_mesh = init_device_mesh(device_name, mesh_shape=(world_size,), mesh_dim_names=['fsdp']) else: - device_mesh = init_device_mesh('cuda', + device_mesh = init_device_mesh(device_name, mesh_shape=(world_size // fsdp_size, fsdp_size), mesh_dim_names=['ddp', 'fsdp']) return device_mesh @@ -78,7 +81,7 @@ class ActorRolloutRefWorker(Worker): self.config = config import torch.distributed if not torch.distributed.is_initialized(): - torch.distributed.init_process_group() + torch.distributed.init_process_group(backend="nccl" if is_cuda_available else "hccl") # build device mesh for FSDP world_size = torch.distributed.get_world_size() @@ -90,7 +93,7 @@ class ActorRolloutRefWorker(Worker): self.ulysses_sequence_parallel_size = self.config.actor.get('ulysses_sequence_parallel_size', 1) dp = world_size // self.ulysses_sequence_parallel_size if self.ulysses_sequence_parallel_size > 1: - self.ulysses_device_mesh = init_device_mesh('cuda', + self.ulysses_device_mesh = init_device_mesh(device_name, mesh_shape=(dp, self.ulysses_sequence_parallel_size), mesh_dim_names=['dp', 'sp']) @@ -257,7 +260,7 @@ class ActorRolloutRefWorker(Worker): param_init_fn=init_fn, use_orig_params=False, auto_wrap_policy=auto_wrap_policy, - device_id=torch.cuda.current_device(), + device_id=get_torch_device().current_device(), sharding_strategy=sharding_strategy, # zero3 mixed_precision=mixed_precision, sync_module_states=True, @@ -298,7 +301,7 @@ class ActorRolloutRefWorker(Worker): infer_tp = self.config.rollout.tensor_model_parallel_size dp = self.world_size // infer_tp assert self.world_size % infer_tp == 0, f'rollout world_size: {self.world_size} is not divisible by infer_tp: {infer_tp}' - rollout_device_mesh = init_device_mesh('cuda', mesh_shape=(dp, infer_tp), mesh_dim_names=['dp', 'infer_tp']) + rollout_device_mesh = init_device_mesh(device_name, mesh_shape=(dp, infer_tp), mesh_dim_names=['dp', 'infer_tp']) rollout_name = self.config.rollout.name if rollout_name == 'hf': from verl.workers.rollout import HFRollout @@ -438,13 +441,13 @@ class ActorRolloutRefWorker(Worker): @register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO) def update_actor(self, data: DataProto): # Support all hardwares - data = data.to(torch.cuda.current_device()) + data = data.to(get_torch_device().current_device()) assert self._is_actor if self._is_offload_param: load_fsdp_model_to_gpu(self.actor_module_fsdp) if self._is_offload_optimizer: - load_fsdp_optimizer(optimizer=self.actor_optimizer, device_id=torch.cuda.current_device()) + load_fsdp_optimizer(optimizer=self.actor_optimizer, device_id=get_torch_device().current_device()) log_gpu_memory_usage('Before update policy', logger=logger) @@ -484,7 +487,7 @@ class ActorRolloutRefWorker(Worker): @register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO) def generate_sequences(self, prompts: DataProto): # Support all hardwares - prompts = prompts.to(torch.cuda.current_device()) + prompts = prompts.to(get_torch_device().current_device()) assert self._is_rollout if self._is_offload_param: @@ -528,7 +531,7 @@ class ActorRolloutRefWorker(Worker): load_fsdp_model_to_gpu(self.actor_module_fsdp) # Support all hardwares - data = data.to(torch.cuda.current_device()) + data = data.to(get_torch_device().current_device()) # we should always 
recompute old_log_probs when it is HybridEngine data.meta_info['micro_batch_size'] = self.config.rollout.log_prob_micro_batch_size_per_gpu data.meta_info['max_token_len'] = self.config.rollout.log_prob_max_token_len_per_gpu @@ -560,7 +563,7 @@ class ActorRolloutRefWorker(Worker): assert self._is_ref # Support all hardwares - data = data.to(torch.cuda.current_device()) + data = data.to(get_torch_device().current_device()) micro_batch_size = self.config.ref.log_prob_micro_batch_size_per_gpu data.meta_info['micro_batch_size'] = micro_batch_size @@ -621,7 +624,7 @@ class CriticWorker(Worker): super().__init__() import torch.distributed if not torch.distributed.is_initialized(): - torch.distributed.init_process_group(backend="nccl") + torch.distributed.init_process_group(backend="nccl" if is_cuda_available else "hccl") self.config = config # build device mesh for Ulysses Sequence Parallel @@ -635,7 +638,7 @@ class CriticWorker(Worker): self.ulysses_sequence_parallel_size = self.config.get('ulysses_sequence_parallel_size', 1) dp = world_size // self.ulysses_sequence_parallel_size if self.ulysses_sequence_parallel_size > 1: - self.ulysses_device_mesh = init_device_mesh('cuda', + self.ulysses_device_mesh = init_device_mesh(device_name, mesh_shape=(dp, self.ulysses_sequence_parallel_size), mesh_dim_names=['dp', 'sp']) @@ -749,7 +752,7 @@ class CriticWorker(Worker): param_init_fn=init_fn, use_orig_params=False, auto_wrap_policy=auto_wrap_policy, - device_id=torch.cuda.current_device(), + device_id=get_torch_device().current_device(), sharding_strategy=sharding_strategy, mixed_precision=mixed_precision, sync_module_states=True, @@ -808,7 +811,7 @@ class CriticWorker(Worker): def compute_values(self, data: DataProto): # Support all hardwares - data = data.to(torch.cuda.current_device()) + data = data.to(get_torch_device().current_device()) if self._is_offload_param: load_fsdp_model_to_gpu(self.critic_module) @@ -831,11 +834,11 @@ class CriticWorker(Worker): @register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO) def update_critic(self, data: DataProto): # Support all hardwares - data = data.to(torch.cuda.current_device()) + data = data.to(get_torch_device().current_device()) if self._is_offload_param: load_fsdp_model_to_gpu(self.critic_module) if self._is_offload_optimizer: - load_fsdp_optimizer(optimizer=self.critic_optimizer, device_id=torch.cuda.current_device()) + load_fsdp_optimizer(optimizer=self.critic_optimizer, device_id=get_torch_device().current_device()) # perform forward computation with self.ulysses_sharding_manager: @@ -907,7 +910,7 @@ class RewardModelWorker(Worker): super().__init__() import torch.distributed if not torch.distributed.is_initialized(): - torch.distributed.init_process_group(backend="nccl") + torch.distributed.init_process_group(backend="nccl" if is_cuda_available else "hccl") self.config = config # build device mesh for Ulysses Sequence Parallel @@ -921,7 +924,7 @@ class RewardModelWorker(Worker): self.ulysses_sequence_parallel_size = self.config.get('ulysses_sequence_parallel_size', 1) dp = world_size // self.ulysses_sequence_parallel_size if self.ulysses_sequence_parallel_size > 1: - self.ulysses_device_mesh = init_device_mesh('cuda', + self.ulysses_device_mesh = init_device_mesh(device_name, mesh_shape=(dp, self.ulysses_sequence_parallel_size), mesh_dim_names=['dp', 'sp']) @@ -984,7 +987,7 @@ class RewardModelWorker(Worker): param_init_fn=init_fn, use_orig_params=False, auto_wrap_policy=auto_wrap_policy, - device_id=torch.cuda.current_device(), + 
device_id=get_torch_device().current_device(), sharding_strategy=sharding_strategy, # zero3 sync_module_states=True, cpu_offload=CPUOffload(offload_params=True), @@ -1000,10 +1003,14 @@ class RewardModelWorker(Worker): self.reward_module = self._build_model(config=self.config) def _forward_micro_batch(self, micro_batch): - from flash_attn.bert_padding import pad_input, unpad_input, index_first_axis, rearrange + if is_cuda_available: + from flash_attn.bert_padding import pad_input, unpad_input, rearrange, index_first_axis + elif is_npu_available: + from transformers.integrations.npu_flash_attention import pad_input, unpad_input, rearrange, \ + index_first_axis from verl.utils.ulysses import ulysses_pad_and_slice_inputs, gather_outpus_and_unpad - with torch.no_grad(), torch.autocast(device_type='cuda', dtype=torch.bfloat16): + with torch.no_grad(), torch.autocast(device_type=device_name, dtype=torch.bfloat16): input_ids = micro_batch['input_ids'] batch_size, seqlen = input_ids.shape attention_mask = micro_batch['attention_mask'] @@ -1131,7 +1138,7 @@ class RewardModelWorker(Worker): import itertools from verl.utils.seqlen_balancing import rearrange_micro_batches, get_reverse_idx # Support all hardwares - data = data.to(torch.cuda.current_device()) + data = data.to(get_torch_device().current_device()) if self._do_switch_chat_template: rm_data = self._switch_chat_template(data) else: @@ -1146,7 +1153,7 @@ class RewardModelWorker(Worker): rm_data = DataProto.from_dict(rm_inputs) # Support all hardwares - rm_data.batch = rm_data.batch.to(torch.cuda.current_device()) + rm_data.batch = rm_data.batch.to(get_torch_device().current_device()) # perform forward computation with self.ulysses_sharding_manager: diff --git a/PyTorch/built-in/rl/VeRL_for_PyTorch/verl/workers/rollout/vllm_rollout/__init__.py b/PyTorch/built-in/rl/VeRL_for_PyTorch/verl/workers/rollout/vllm_rollout/__init__.py index 0d6d4c3d818b73e33a9403bc698c917fbe8aaef3..39a883c4a1d3a5c54ec8d968d9666557f625f44e 100644 --- a/PyTorch/built-in/rl/VeRL_for_PyTorch/verl/workers/rollout/vllm_rollout/__init__.py +++ b/PyTorch/built-in/rl/VeRL_for_PyTorch/verl/workers/rollout/vllm_rollout/__init__.py @@ -13,6 +13,7 @@ # limitations under the License. 
from importlib.metadata import version, PackageNotFoundError +from verl.utils.device import get_device_name ### # [SUPPORT AMD:] @@ -29,11 +30,12 @@ def get_version(pkg): package_name = 'vllm' package_version = get_version(package_name) +device_name = get_device_name() ### # package_version = get_version(package_name) # [SUPPORT AMD:] -if "AMD" in torch.cuda.get_device_name(): +if "AMD" in device_name: import re package_version = version(package_name) package_version = re.match(r'(\d+\.\d+\.?\d*)', package_version).group(1) diff --git a/PyTorch/built-in/rl/VeRL_for_PyTorch/verl/workers/sharding_manager/fsdp_vllm.py b/PyTorch/built-in/rl/VeRL_for_PyTorch/verl/workers/sharding_manager/fsdp_vllm.py index ca990ea5748398527da7bfb1ef44cbee3938a492..a1bdfc460671082dce648075d3dfc8111b7c4415 100644 --- a/PyTorch/built-in/rl/VeRL_for_PyTorch/verl/workers/sharding_manager/fsdp_vllm.py +++ b/PyTorch/built-in/rl/VeRL_for_PyTorch/verl/workers/sharding_manager/fsdp_vllm.py @@ -27,6 +27,7 @@ from verl.utils.torch_functional import (broadcast_dict_tensor, allgather_dict_t from verl.protocol import all_gather_data_proto from verl.utils.debug import log_gpu_memory_usage from verl.third_party.vllm import vllm_version +from verl.utils.device import get_torch_device from .base import BaseShardingManager from .patch import patched_ds_v3_load_weights @@ -63,13 +64,13 @@ class FSDPVLLMShardingManager(BaseShardingManager): self.tp_rank = vllm_ps.get_tensor_model_parallel_rank() # Note that torch_random_states may be different on each dp rank - self.torch_random_states = torch.cuda.get_rng_state() + self.torch_random_states = get_torch_device().get_rng_state() # get a random rng states if self.device_mesh is not None: gen_dp_rank = self.device_mesh['dp'].get_local_rank() - torch.cuda.manual_seed(gen_dp_rank + 1000) # make sure all tp ranks have the same random states - self.gen_random_states = torch.cuda.get_rng_state() - torch.cuda.set_rng_state(self.torch_random_states) + get_torch_device().manual_seed(gen_dp_rank + 1000) # make sure all tp ranks have the same random states + self.gen_random_states = get_torch_device().get_rng_state() + get_torch_device().set_rng_state(self.torch_random_states) else: self.gen_random_states = None @@ -117,8 +118,8 @@ class FSDPVLLMShardingManager(BaseShardingManager): # important: need to manually set the random states of each tp to be identical. if self.device_mesh is not None: - self.torch_random_states = torch.cuda.get_rng_state() - torch.cuda.set_rng_state(self.gen_random_states) + self.torch_random_states = get_torch_device().get_rng_state() + get_torch_device().set_rng_state(self.gen_random_states) def __exit__(self, exc_type, exc_value, traceback): log_gpu_memory_usage('Before vllm offload in sharding manager', logger=logger) @@ -136,12 +137,12 @@ class FSDPVLLMShardingManager(BaseShardingManager): self.module.train() # add empty cache after each compute - torch.cuda.empty_cache() + get_torch_device().empty_cache() # restore random states if self.device_mesh is not None: - self.gen_random_states = torch.cuda.get_rng_state() - torch.cuda.set_rng_state(self.torch_random_states) + self.gen_random_states = get_torch_device().get_rng_state() + get_torch_device().set_rng_state(self.torch_random_states) def preprocess_data(self, data: DataProto) -> DataProto: """All gather across tp group to make each rank has identical input."""