From 67cb07f6953fa57c4a12bfc83590facc7b63e8b9 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E7=8E=8B=E5=87=AF=E5=AE=87?=
Date: Fri, 9 May 2025 16:24:44 +0800
Subject: [PATCH] [built-in][PyTorch][OpenRLHF] add OpenRLHF-v0.5.7 adaptation code

---
 .../openrlhf/cli/__init__.py                  |   5 +
 .../openrlhf/cli/train_ppo.py                 |   2 +-
 .../openrlhf/cli/train_ppo_ray.py             |   9 +-
 .../openrlhf/cli/train_sft.py                 |   6 -
 .../openrlhf/datasets/sft_dataset.py          |  57 +-
 .../openrlhf/models/model.py                  |  12 +-
 .../openrlhf/trainer/ray/ppo_actor.py         |  44 +-
 .../openrlhf/trainer/ray/vllm_engine.py       |  20 +-
 .../openrlhf/trainer/ray/vllm_worker_wrap.py  |  33 +-
 .../openrlhf/utils/vision_utils.py            | 703 ++++++++++++++++++
 .../requirements.txt                          |  10 +-
 .../OpenRLHF_v0.5.7_for_PyTorch/version.txt   |   2 +-
 12 files changed, 757 insertions(+), 146 deletions(-)
 create mode 100644 PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/openrlhf/utils/vision_utils.py

diff --git a/PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/openrlhf/cli/__init__.py b/PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/openrlhf/cli/__init__.py
index e69de29bb2..764afa7e68 100644
--- a/PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/openrlhf/cli/__init__.py
+++ b/PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/openrlhf/cli/__init__.py
@@ -0,0 +1,5 @@
+from transformers import is_torch_npu_available
+
+if is_torch_npu_available():
+    import torch_npu
+    from torch_npu.contrib import transfer_to_npu
diff --git a/PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/openrlhf/cli/train_ppo.py b/PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/openrlhf/cli/train_ppo.py
index 2fc03bac05..c21009b855 100644
--- a/PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/openrlhf/cli/train_ppo.py
+++ b/PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/openrlhf/cli/train_ppo.py
@@ -311,7 +311,7 @@ if __name__ == "__main__":
     parser.add_argument("--ptx_coef", type=float, default=0.05, help="PPO-ptx loss coef")
     parser.add_argument("--eps_clip", type=float, default=0.2, help="PPO clip range")
     parser.add_argument("--value_clip", type=float, default=0.2, help="PPO value clip range")
-    parser.add_argument("--lambd", type=float, default=1.0, help="PPO GAE lambd")
+    parser.add_argument("--lambd", type=float, default=0.95, help="PPO GAE lambd")
     parser.add_argument("--gamma", type=float, default=1, help="PPO GAE gamma")
     parser.add_argument("--micro_train_batch_size", type=int, default=4, help="batch size per GPU")
     parser.add_argument("--train_batch_size", type=int, default=128, help="Global training batch size")
diff --git a/PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/openrlhf/cli/train_ppo_ray.py b/PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/openrlhf/cli/train_ppo_ray.py
index 245b855314..9bbd14a5c4 100644
--- a/PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/openrlhf/cli/train_ppo_ray.py
+++ b/PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/openrlhf/cli/train_ppo_ray.py
@@ -213,7 +213,6 @@ if __name__ == "__main__":
         help="tensor parallel size of vLLM Engine for multi-GPU inference",
     )
     parser.add_argument("--vllm_sync_backend", type=str, default="nccl", help="DeepSpeed -> vLLM weight sync backend")
-    parser.add_argument("--vllm_sync_with_ray", action="store_true", default=False)
     parser.add_argument("--enable_prefix_caching", action="store_true", default=False)
     parser.add_argument("--enforce_eager", action="store_true", default=False, help="Disable CUDA graph in vLLM")

@@ -269,7 +268,7 @@ if __name__ == "__main__":
     parser.add_argument("--ptx_coef", type=float, default=0.05, help="PPO-ptx loss coef")
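# ---- editor's sketch, not part of the patch ----
# What the --lambd default change above (1.0 -> 0.95) tunes: generalized
# advantage estimation. With lambd=1.0 the advantage reduces to the full
# discounted return minus the value baseline (unbiased, high variance);
# 0.95 shortens the credit-assignment horizon for lower variance. A minimal
# standalone illustration; function name and shapes are assumptions, this is
# not OpenRLHF's actual routine.
import torch

def gae_advantages(rewards, values, gamma=1.0, lambd=0.95):
    # rewards[t], values[t]: 1-D float tensors over one trajectory; the
    # episode is assumed to terminate after the last step (V = 0 beyond it).
    advantages = torch.zeros_like(rewards)
    last_gae = 0.0
    for t in reversed(range(rewards.numel())):
        next_value = values[t + 1] if t + 1 < rewards.numel() else 0.0
        delta = rewards[t] + gamma * next_value - values[t]  # TD residual
        last_gae = delta + gamma * lambd * last_gae          # GAE recursion
        advantages[t] = last_gae
    return advantages
# ---- end editor's sketch ----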
parser.add_argument("--eps_clip", type=float, default=0.2, help="PPO clip range") parser.add_argument("--value_clip", type=float, default=0.2, help="PPO value clip range") - parser.add_argument("--lambd", type=float, default=1.0, help="PPO GAE lambd") + parser.add_argument("--lambd", type=float, default=0.95, help="PPO GAE lambd") parser.add_argument("--gamma", type=float, default=1, help="PPO GAE gamma") parser.add_argument("--micro_train_batch_size", type=int, default=4, help="batch size per GPU") parser.add_argument("--train_batch_size", type=int, default=128, help="Global training batch size") @@ -375,10 +374,8 @@ if __name__ == "__main__": args.remote_rm_url = args.remote_rm_url.split(",") if args.vllm_num_engines >= 1 and args.enable_prefix_caching: - import vllm - if vllm.__version__ < "0.7.0": - args.enable_prefix_caching = False - print("[Warning] Disable prefix cache because vLLM updates weights without updating the old KV Cache for vLLM version below 0.7.0.") + args.enable_prefix_caching = False + print("[Warning] Disable prefix cache because vLLM updates weights without updating the old KV Cache.") if args.input_template and "{}" not in args.input_template: print("[Warning] {} not in args.input_template, set to None") diff --git a/PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/openrlhf/cli/train_sft.py b/PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/openrlhf/cli/train_sft.py index 843e37adad..ad3b2af98c 100644 --- a/PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/openrlhf/cli/train_sft.py +++ b/PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/openrlhf/cli/train_sft.py @@ -63,7 +63,6 @@ def train(args): pretrain_mode=args.pretrain_mode, input_template=args.input_template, multiple_of=args.ring_attn_size, - multiturn=args.multiturn, ) eval_dataset = SFTDataset( eval_data, @@ -73,7 +72,6 @@ def train(args): pretrain_mode=args.pretrain_mode, input_template=args.input_template, multiple_of=args.ring_attn_size, - multiturn=args.multiturn, ) # prepare dataloader @@ -207,7 +205,6 @@ if __name__ == "__main__": parser.add_argument("--dataset_probs", type=str, default="1.0", help="sampling probs for datasets") parser.add_argument("--train_split", type=str, default="train", help="train split of the HF dataset") parser.add_argument("--eval_split", type=str, default="test", help="test split of the dataset") - parser.add_argument("--multiturn", action="store_true", default=False, help="Use compacted multiturn dataset") parser.add_argument("--input_key", type=str, default="input", help="JSON dataset key") parser.add_argument("--output_key", type=str, default=None, help="JSON dataset key") @@ -235,9 +232,6 @@ if __name__ == "__main__": args = parser.parse_args() - if args.multiturn: - assert args.apply_chat_template, "apply_chat_template must be enabled when using multiturn format" - if args.input_template and "{}" not in args.input_template: print("[Warning] {} not in args.input_template, set to None") args.input_template = None diff --git a/PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/openrlhf/datasets/sft_dataset.py b/PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/openrlhf/datasets/sft_dataset.py index 6e031f70ab..e5e0c004e9 100644 --- a/PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/openrlhf/datasets/sft_dataset.py +++ b/PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/openrlhf/datasets/sft_dataset.py @@ -7,7 +7,7 @@ from torch.utils.data import Dataset from .utils import zero_pad_sequences -def preprocess_data(data, input_template=None, input_key="input", output_key=None, 
apply_chat_template=None, multiturn=False): +def preprocess_data(data, input_template=None, input_key="input", output_key=None, apply_chat_template=None): if apply_chat_template: if output_key: prompt_message = data[input_key] @@ -51,7 +51,6 @@ class SFTDataset(Dataset): pretrain_mode=False, num_processors=8, # Specify the number of processors you want to use multiple_of=1, - multiturn=False, ) -> None: super().__init__() self.tokenizer = tokenizer @@ -59,7 +58,6 @@ class SFTDataset(Dataset): self.pretrain_mode = pretrain_mode self.max_length = max_length self.multiple_of = multiple_of - self.multiturn = multiturn # chat template self.input_template = input_template @@ -75,9 +73,7 @@ class SFTDataset(Dataset): # Parallel loading datasets processed_dataset = dataset.map( - self.process_data, - remove_columns=dataset.column_names, - num_proc=num_processors, + self.process_data, remove_columns=dataset.column_names, num_proc=num_processors ) processed_dataset = processed_dataset.filter(lambda x: x["prompt"] is not None) @@ -85,51 +81,15 @@ class SFTDataset(Dataset): self.prompts = processed_dataset["prompt"] self.responses = processed_dataset["response"] self.prompt_ids_lens = processed_dataset["prompt_ids_len"] - self.response_ranges = processed_dataset["response_ranges"] if self.multiturn else None def process_data(self, data): - if self.multiturn and self.output_key: - data[self.input_key].append(data[self.output_key]) - data[self.output_key] = None - - if self.multiturn: - assert not self.output_key or not data[self.output_key], "You should put the whole trajactory into data[input_key] and do not set output_key" - input_key = self.input_key - apply_chat_template = self.apply_chat_template - response_ranges = [] - for idx, message in enumerate(data[input_key]): - if message['role'] == 'assistant': - prompt = apply_chat_template(data[input_key][: idx], tokenize=False, add_generation_prompt=True) - response = apply_chat_template(data[input_key][: idx + 1], tokenize=False)[len(prompt):] - - start_idx = self.tokenizer( - prompt, - max_length=self.max_length, - padding=False, - truncation=True, - return_tensors="pt", - add_special_tokens=False, - )["attention_mask"].int().sum().item() - - end_idx = start_idx + self.tokenizer( - response, - max_length=self.max_length, - padding=False, - truncation=True, - return_tensors="pt", - add_special_tokens=False, - )["attention_mask"].int().sum().item() - 1 - response_ranges.append((start_idx, end_idx)) # left close right open - prompt, response = preprocess_data( data, None if self.pretrain_mode else self.input_template, self.input_key, self.output_key, apply_chat_template=None if self.pretrain_mode else self.apply_chat_template, - multiturn=self.multiturn, ) - if not self.pretrain_mode: prompt_token = self.tokenizer( prompt, @@ -147,7 +107,7 @@ class SFTDataset(Dataset): else: prompt_ids_len = 0 - return {"prompt": prompt, "response": response, "prompt_ids_len": prompt_ids_len, "response_ranges": response_ranges if self.multiturn else None} + return {"prompt": prompt, "response": response, "prompt_ids_len": prompt_ids_len} def __len__(self): length = len(self.prompts) @@ -178,7 +138,7 @@ class SFTDataset(Dataset): # to avoid EOS_token truncation input_token["input_ids"][0][-1] = self.tokenizer.eos_token_id input_token["attention_mask"][0][-1] = True - info = {"input": prompt, "output": response, "input_length": input_token["attention_mask"].int().sum().item(), "response_ranges": self.response_ranges[idx] if self.multiturn else None} + info = {"input": 
prompt, "output": response, "input_length": input_token["attention_mask"].int().sum().item()} return prompt_ids_len, input_token["input_ids"], input_token["attention_mask"], info @@ -203,19 +163,14 @@ class SFTDataset(Dataset): packed_input_ids = [] packed_attention_masks = [] prompt_ids_lens = [] - infos = {"input_length": [], "response_ranges": [] if self.multiturn else None} + infos = {"input_length": []} + index = 1 for prompt_ids_len, input_id, attention_mask, info in item_list: packed_input_ids.append(input_id.flatten()) packed_attention_masks.append(torch.full_like(input_id.flatten(), index)) prompt_ids_lens.append(prompt_ids_len) infos["input_length"].append(info["input_length"]) - if self.multiturn: - if len(infos["response_ranges"]) >= 1: - for i in range(len(info["response_ranges"])): - info["response_ranges"][i][0] += infos["response_ranges"][-1][-1][1] # end_index of the last response of the last item - info["response_ranges"][i][1] += infos["response_ranges"][-1][-1][1] - infos["response_ranges"].append(info["response_ranges"]) index += 1 packed_input_ids = torch.cat(packed_input_ids, dim=0).unsqueeze(0) diff --git a/PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/openrlhf/models/model.py b/PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/openrlhf/models/model.py index 3d2102dc94..a7a71ae5b7 100644 --- a/PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/openrlhf/models/model.py +++ b/PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/openrlhf/models/model.py @@ -3,17 +3,20 @@ from typing import Optional, Union import deepspeed import torch import torch.nn as nn -from flash_attn.utils.distributed import all_gather from peft import LoraConfig, get_peft_model from peft.tuners.lora import LoraLayer from transformers import AutoConfig, AutoModel, BitsAndBytesConfig from transformers.integrations.deepspeed import HfDeepSpeedConfig +from transformers.utils import is_flash_attn_2_available from openrlhf.utils.logging_utils import init_logger from .ring_attn_utils import convert_ring_attn_params from .utils import reset_position_ids +if is_flash_attn_2_available(): + from flash_attn.utils.distributed import all_gather + logger = init_logger(__name__) @@ -68,7 +71,12 @@ def get_llm_for_sequence_regression( config = AutoConfig.from_pretrained(model_name_or_path, trust_remote_code=True) config.normalize_reward = normalize_reward - config._attn_implementation = "flash_attention_2" if use_flash_attention_2 else "eager" + if use_flash_attention_2 == "fa2": + config._attn_implementation = "flash_attention_2" + elif use_flash_attention_2 == "sdpa": + config._attn_implementation = "sdpa" + else: + config._attn_implementation = "eager" # Prioritize using the value_head_prefix in the model configuration. 
value_head_prefix = getattr(config, "value_head_prefix", value_head_prefix) diff --git a/PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/openrlhf/trainer/ray/ppo_actor.py b/PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/openrlhf/trainer/ray/ppo_actor.py index 9661b0edb0..1d8bb59c45 100644 --- a/PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/openrlhf/trainer/ray/ppo_actor.py +++ b/PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/openrlhf/trainer/ray/ppo_actor.py @@ -82,37 +82,24 @@ class ActorPPOTrainer(PPOTrainer): world_size = vllm_num_engines * vllm_tensor_parallel_size + 1 backend = getattr(self.strategy.args, "vllm_sync_backend", "nccl") - use_ray = getattr(self.strategy.args, "vllm_sync_with_ray", False) - group_name = "openrlhf" refs = [ engine.init_process_group.remote( master_address, master_port, i * vllm_tensor_parallel_size + 1, world_size, - group_name, + "openrlhf", backend=backend, - use_ray=use_ray, ) for i, engine in enumerate(self.vllm_engines) ] - if use_ray: - import ray.util.collective as collective - collective.init_collective_group( - world_size=world_size, - rank=0, - backend=backend, - group_name=group_name - ) - self._model_update_group = group_name - else: - self._model_update_group = init_process_group( - backend=backend, - init_method=f"tcp://{master_address}:{master_port}", - world_size=world_size, - rank=0, - group_name=group_name, - ) + self._model_update_group = init_process_group( + backend=backend, + init_method=f"tcp://{master_address}:{master_port}", + world_size=world_size, + rank=0, + group_name="openrlhf", + ) ray.get(refs) @@ -149,15 +136,8 @@ class ActorPPOTrainer(PPOTrainer): return self.training_step_actor(experience) def _broadcast_to_vllm(self): - use_prefix_cache = getattr(self.strategy.args, "enable_prefix_caching", False) - cache_reset_refs = [] - if use_prefix_cache and torch.distributed.get_rank() == 0: - # clear prefix cache - for engine in self.vllm_engines: - cache_reset_refs.append(engine.reset_prefix_cache.remote()) # avoid OOM torch.cuda.empty_cache() - use_ray = getattr(self.strategy.args, "vllm_sync_with_ray", False) model = self.actor.model.module count, num_params = 0, len(list(model.named_parameters())) for name, param in model.named_parameters(): @@ -174,14 +154,8 @@ class ActorPPOTrainer(PPOTrainer): # For ZeRO-3, allgather sharded parameter and broadcast to all vllm engines by rank 0 with deepspeed.zero.GatheredParameters([param], enabled=self.strategy.args.zero_stage == 3): if torch.distributed.get_rank() == 0: - if use_ray: - import ray.util.collective as collective - collective.broadcast(param.data, 0, group_name=self._model_update_group) - else: - torch.distributed.broadcast(param.data, 0, group=self._model_update_group) + torch.distributed.broadcast(param.data, 0, group=self._model_update_group) ray.get(refs) - if cache_reset_refs: - ray.get(cache_reset_refs) torch.distributed.barrier() def _save_checkpoint(self, args, tag, client_states): diff --git a/PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/openrlhf/trainer/ray/vllm_engine.py b/PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/openrlhf/trainer/ray/vllm_engine.py index 889b034242..733c57effb 100644 --- a/PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/openrlhf/trainer/ray/vllm_engine.py +++ b/PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/openrlhf/trainer/ray/vllm_engine.py @@ -35,11 +35,7 @@ class LLMRayActor: else: # RayGPUExecutor # See the patch https://github.com/vllm-project/vllm/commit/479d69fad0538f04cb22bf13e76ff91cfeb8a4e5 - if vllm.__version__ >= 
"0.4.3": - # https://github.com/vllm-project/vllm/commit/676a99982fe9aabe72fd52a91e08988a653a7359 - kwargs["distributed_executor_backend"] = "ray" - else: - kwargs["worker_use_ray"] = True + kwargs["worker_use_ray"] = True if vllm.__version__ > "0.6.4.post1": # https://github.com/vllm-project/vllm/pull/10555 @@ -60,14 +56,14 @@ class LLMRayActor: def generate(self, *args, **kwargs): return self.llm.generate(*args, **kwargs) - def init_process_group(self, master_address, master_port, rank_offset, world_size, group_name, backend, use_ray): + def init_process_group(self, master_address, master_port, rank_offset, world_size, group_name, backend): if self.use_gpu_executor: return self.llm.llm_engine.model_executor.driver_worker.init_process_group( - master_address, master_port, rank_offset, world_size, group_name, backend, use_ray + master_address, master_port, rank_offset, world_size, group_name, backend ) else: return self.llm.llm_engine.model_executor._run_workers( - "init_process_group", master_address, master_port, rank_offset, world_size, group_name, backend, use_ray + "init_process_group", master_address, master_port, rank_offset, world_size, group_name, backend ) def update_weight(self, name, dtype, shape, empty_cache=False): @@ -78,14 +74,6 @@ class LLMRayActor: else: return self.llm.llm_engine.model_executor._run_workers("update_weight", name, dtype, shape, empty_cache) - def reset_prefix_cache(self): - import vllm - if vllm.__version__ < "0.7.0": - # https://github.com/vllm-project/vllm/commit/7206ce4ce112ed117796a59045c968a6d353f691 - logger.warning("Reset prefix cache API is available only from vLLM 0.7.0!") - return - self.llm.llm_engine.reset_prefix_cache() - def stop_remote_worker_execution_loop(self): # Fix error for using 2 communication group # https://github.com/vllm-project/vllm/commit/eb6d3c264d0cd8e44dec16bca7947fbe96415ce9#diff-e1ad69e38e033accddfa5480ec808c4740eb39244d1ef51cc3407e20dde8cfd4 diff --git a/PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/openrlhf/trainer/ray/vllm_worker_wrap.py b/PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/openrlhf/trainer/ray/vllm_worker_wrap.py index 2f324793d0..730dd12b85 100644 --- a/PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/openrlhf/trainer/ray/vllm_worker_wrap.py +++ b/PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/openrlhf/trainer/ray/vllm_worker_wrap.py @@ -8,30 +8,19 @@ logger = init_logger(__name__) class WorkerWrap(Worker): - def init_process_group(self, master_address, master_port, rank_offset, world_size, group_name, backend="nccl", use_ray=False): + def init_process_group(self, master_address, master_port, rank_offset, world_size, group_name, backend="nccl"): """Init torch process group for model weights update""" assert torch.distributed.is_initialized(), f"default torch process group must be initialized" assert group_name != "", f"group name must not be empty" rank = torch.distributed.get_rank() + rank_offset - if use_ray: - import ray.util.collective as collective - collective.init_collective_group( - world_size=world_size, - rank=rank, - backend=backend, - group_name=group_name - ) - self._model_update_group = group_name - else: - self._model_update_group = init_process_group( - backend=backend, - init_method=f"tcp://{master_address}:{master_port}", - world_size=world_size, - rank=rank, - group_name=group_name, - ) - self._model_update_with_ray = use_ray + self._model_update_group = init_process_group( + backend=backend, + init_method=f"tcp://{master_address}:{master_port}", + world_size=world_size, + rank=rank, 
+            group_name=group_name,
+        )
         print(
             f"init_process_group: master_address={master_address}, master_port={master_port}, ",
             f"rank={rank}, world_size={world_size}, group_name={group_name}",
@@ -44,11 +33,7 @@ class WorkerWrap(Worker):
         assert dtype == self.model_config.dtype, f"mismatch dtype: src {dtype}, dst {self.model_config.dtype}"
         weight = torch.empty(shape, dtype=dtype, device="cuda")
-        if self._model_update_with_ray:
-            import ray.util.collective as collective
-            collective.broadcast(weight, 0, group_name=self._model_update_group)
-        else:
-            torch.distributed.broadcast(weight, 0, group=self._model_update_group)
+        torch.distributed.broadcast(weight, 0, group=self._model_update_group)

         self.model_runner.model.load_weights(weights=[(name, weight)])
diff --git a/PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/openrlhf/utils/vision_utils.py b/PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/openrlhf/utils/vision_utils.py
new file mode 100644
index 0000000000..8c4e18fc90
--- /dev/null
+++ b/PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/openrlhf/utils/vision_utils.py
@@ -0,0 +1,703 @@
+import json
+import math
+import os
+import re
+from abc import ABC, abstractmethod
+from copy import deepcopy
+from dataclasses import dataclass, field
+from io import BytesIO
+from typing import Any, Dict, List, Optional, Sequence, Set, Tuple, Union
+
+import numpy as np
+import torch
+from PIL import Image
+from PIL.Image import Image as ImageObject
+from transformers import AutoConfig, AutoProcessor
+from typing_extensions import override
+
+IGNORE_INDEX = -100
+ImageInput = Union[str, bytes, ImageObject]
+SLOTS = Sequence[Union[str, Set[str], Dict[str, str]]]
+IMAGE_PLACEHOLDER = os.environ.get("IMAGE_PLACEHOLDER", "<image>")
+VIDEO_PLACEHOLDER = os.environ.get("VIDEO_PLACEHOLDER", "