diff --git a/tests/st/python/cases_parallel/vllm_qwen3.py b/tests/st/python/cases_parallel/vllm_qwen3.py
new file mode 100644
index 0000000000000000000000000000000000000000..e559d9faef913ed6cce9445ec677e679529e97a2
--- /dev/null
+++ b/tests/st/python/cases_parallel/vllm_qwen3.py
@@ -0,0 +1,113 @@
+#!/usr/bin/env python3
+# encoding: utf-8
+# Copyright 2025 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+# isort:skip_file
+"""test vllm qwen3."""
+import os
+
+import pytest
+
+from tests.st.python import set_env
+
+env_manager = set_env.EnvVarManager()
+# def env
+env_vars = {
+    "ASCEND_CUSTOM_PATH": os.path.expandvars("$ASCEND_HOME_PATH/../"),
+    "MS_ENABLE_LCCL": "off",
+    "HCCL_OP_EXPANSION_MODE": "AIV",
+    "MS_ALLOC_CONF": "enable_vmm:True",
+    "LCCL_DETERMINISTIC": "1",
+    "HCCL_DETERMINISTIC": "true",
+    "ATB_MATMUL_SHUFFLE_K_ENABLE": "0",
+    "ATB_LLM_LCOC_ENABLE": "0",
+    "VLLM_USE_V1": "1",
+    "HCCL_IF_BASE_PORT": "60000"
+}
+# set env
+env_manager.setup_ai_environment(env_vars)
+import vllm_mindspore
+from vllm import LLM, SamplingParams
+
+
+def test_vllm_qwen3_8b():
+    """
+    test case qwen3 8B
+    """
+
+    # Sample prompts.
+    prompts = [
+        "<|im_start|>user\n将文本分类为中性、负面或正面。 "
+        "\n文本:我认为这次假期还可以。 \n情感:<|im_end|>\n<|im_start|>assistant\n",
+    ]
+
+    # Create a sampling params object.
+    sampling_params = SamplingParams(temperature=0.0, max_tokens=10, top_k=1)
+
+    # Create an LLM.
+    llm = LLM(model="/home/workspace/mindspore_dataset/weight/Qwen3-8B",
+              gpu_memory_utilization=0.9,
+              tensor_parallel_size=1,
+              max_model_len=4096)
+    # Generate texts from the prompts. The output is a list of RequestOutput objects
+    # that contain the prompt, generated text, and other information.
+    outputs = llm.generate(prompts, sampling_params)
+    except_list = ['\n好的,我现在需要处理用户的查询,']
+    # Print the outputs.
+    for i, output in enumerate(outputs):
+        prompt = output.prompt
+        generated_text = output.outputs[0].text
+        print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
+        assert generated_text == except_list[
+            i], f"Expected: {except_list[i]}, but got: {generated_text}"
+
+    # unset env
+    env_manager.unset_all()
+
+
+def test_vllm_qwen3_0_6b():
+    """
+    test case qwen3 0.6B
+    """
+
+    # Sample prompts.
+    prompts = [
+        "<|im_start|>user\n将文本分类为中性、负面或正面。 "
+        "\n文本:我认为这次假期还可以。 \n情感:<|im_end|>\n<|im_start|>assistant\n\n\n\n\n",
+    ]
+
+    # Create a sampling params object.
+    sampling_params = SamplingParams(temperature=0.0, max_tokens=10, top_k=1)
+
+    # Create an LLM.
+    llm = LLM(model="/home/workspace/mindspore_dataset/weight/Qwen3-0.6B",
+              gpu_memory_utilization=0.9,
+              tensor_parallel_size=1,
+              max_model_len=4096)
+    # Generate texts from the prompts. The output is a list of RequestOutput objects
+    # that contain the prompt, generated text, and other information.
+    outputs = llm.generate(prompts, sampling_params)
+    except_list = ['情感:中性']
+    # Print the outputs.
+    for i, output in enumerate(outputs):
+        prompt = output.prompt
+        generated_text = output.outputs[0].text
+        print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
+        assert generated_text == except_list[
+            i], f"Expected: {except_list[i]}, but got: {generated_text}"
+
+    # unset env
+    env_manager.unset_all()
diff --git a/tests/st/python/test_cases_parallel.py b/tests/st/python/test_cases_parallel.py
index b5f93c6a72099f87b09700671d291f6aeeb4af15..d6f282464e85dc03aad05842626291aa373e68fc 100644
--- a/tests/st/python/test_cases_parallel.py
+++ b/tests/st/python/test_cases_parallel.py
@@ -243,6 +243,16 @@ def test_cases_parallel_part5():
          "pytest -s -v cases_parallel/vllm_llama3.py::test_vllm_llama3_1b "
          "> vllm_llama3_1b_test_vllm_llama3.log",
          "vllm_llama3_1b_test_vllm_llama3.log"),
+        ("export ASCEND_RT_VISIBLE_DEVICES=6 && export LCAL_COMM_ID=127.0.0.1:10072 && "
+         "export HCCL_IF_BASE_PORT=61008 && "
+         "pytest -s -v cases_parallel/vllm_qwen3.py::test_vllm_qwen3_8b "
+         "> vllm_qwen3_test_vllm_qwen3_8b.log",
+         "vllm_qwen3_test_vllm_qwen3_8b.log"),
+        ("export ASCEND_RT_VISIBLE_DEVICES=7 && export LCAL_COMM_ID=127.0.0.1:10073 && "
+         "export HCCL_IF_BASE_PORT=61010 && "
+         "pytest -s -v cases_parallel/vllm_qwen3.py::test_vllm_qwen3_0_6b "
+         "> vllm_qwen3_test_vllm_qwen3_0_6b.log",
+         "vllm_qwen3_test_vllm_qwen3_0_6b.log")
     ]
 
     with Pool(len(commands)) as pool:
diff --git a/vllm_mindspore/model_executor/layers/rotary_embedding.py b/vllm_mindspore/model_executor/layers/rotary_embedding.py
index 7470233475254ccccdf677d5627a6f3bd59f6408..1a638051dafca163054e29626e504c74fcfbacc0 100644
--- a/vllm_mindspore/model_executor/layers/rotary_embedding.py
+++ b/vllm_mindspore/model_executor/layers/rotary_embedding.py
@@ -603,6 +603,110 @@ class InferMRotaryEmbedding(InferRotaryEmbedding):
                                  batch_valid_length)
 
 
+def _yarn_get_mscale(scale: float = 1) -> float:
+    if scale <= 1:
+        return 1.0
+    return 0.1 * math.log(scale) + 1.0
+
+
+def _yarn_find_correction_dim(num_rotations: int,
+                              dim: int,
+                              base: float = 10000,
+                              max_position_embeddings: int = 2048) -> float:
+    return (dim * math.log(max_position_embeddings /
+                           (num_rotations * 2 * math.pi))) / (2 *
+                                                              math.log(base))
+
+
+# Find dim range bounds based on rotations
+def _yarn_find_correction_range(
+        low_rot: int,
+        high_rot: int,
+        dim: int,
+        base: float = 10000,
+        max_position_embeddings: int = 2048) -> Tuple[int, int]:
+    low = math.floor(
+        _yarn_find_correction_dim(low_rot, dim, base, max_position_embeddings))
+    high = math.ceil(
+        _yarn_find_correction_dim(high_rot, dim, base,
+                                  max_position_embeddings))
+    return max(low, 0), min(high, dim - 1)  # Clamp values just in case
+
+
+def _yarn_linear_ramp_mask(low: float, high: float, dim: int,
+                           dtype: np.dtype) -> np.ndarray:
+    if low == high:
+        high += 0.001  # Prevent singularity
+
+    linear_func = (np.arange(dim, dtype=dtype) - low) / (high - low)
+    ramp_func = np.clip(linear_func, 0, 1)
+    return ramp_func
+
+
+class InferYaRNScalingRotaryEmbedding(InferRotaryEmbedding):
+
+    def __init__(
+        self,
+        head_size: int,
+        rotary_dim: int,
+        max_position_embeddings: int,
+        base: int,
+        is_neox_style: bool,
+        scaling_factor: float,
+        dtype,
+        *,
+        extrapolation_factor: float = 1,
+        attn_factor: float = 1,
+        beta_fast: int = 32,
+        beta_slow: int = 1,
+    ) -> None:
+        self.scaling_factor = scaling_factor
+        self.extrapolation_factor = extrapolation_factor
+        self.attn_factor = attn_factor
+        self.beta_fast = beta_fast
+        self.beta_slow = beta_slow
+        # Get n-d magnitude scaling corrected for interpolation
+        self.mscale = float(
+            _yarn_get_mscale(self.scaling_factor) * attn_factor)
+        super().__init__(head_size, rotary_dim, max_position_embeddings, base,
+                         is_neox_style, dtype)
+
+    def _compute_inv_freq(self, scaling_factor: float) -> Tensor:
+        pos_freqs = self.base**(
+            np.arange(0, self.rotary_dim, 2, dtype=np.float32) /
+            self.rotary_dim)
+        inv_freq_extrapolation = 1.0 / pos_freqs
+        inv_freq_interpolation = 1.0 / (scaling_factor * pos_freqs)
+
+        low, high = _yarn_find_correction_range(self.beta_fast, self.beta_slow,
+                                                self.rotary_dim, self.base,
+                                                self.max_position_embeddings)
+        # Get n-d rotational scaling corrected for extrapolation
+        inv_freq_mask = (
+            1 - _yarn_linear_ramp_mask(
+                low,
+                high,
+                self.rotary_dim // 2,
+                dtype=np.float32  # type: ignore[arg-type]
+            )) * self.extrapolation_factor
+        inv_freq = inv_freq_interpolation * (
+            1 - inv_freq_mask) + inv_freq_extrapolation * inv_freq_mask
+        return inv_freq
+
+    def _compute_cos_sin_cache(self) -> Tuple[Tensor, Tensor]:
+        freqs = self._compute_inv_freq(self.scaling_factor)
+        t = np.arange(self.max_position_embeddings *
+                      self.scaling_factor).astype(np.float32)
+        self.freqs = Tensor(freqs.reshape(1, 1, 1, -1), dtype=self.dtype)
+        freqs = np.outer(t, freqs)  # (max_position_embedding, head_dim // 2)
+        emb = np.concatenate((freqs, freqs), axis=-1)
+        freqs_cos = np.cos(emb) * self.mscale  # (seq_len, head_dim)
+        freqs_sin = np.sin(emb) * self.mscale  # (seq_len, head_dim)
+        freqs_cos = Tensor(freqs_cos, dtype=self.dtype)
+        freqs_sin = Tensor(freqs_sin, dtype=self.dtype)
+        return freqs_cos, freqs_sin
+
+
 _ROPE_DICT: Dict[Tuple, Union[InferRotaryEmbedding, RotaryEmbedding]] = {}
@@ -671,6 +775,19 @@ def get_rope(
                 )
             else:
                 raise NotImplementedError
+        elif scaling_type == "yarn":
+            scaling_factor = rope_scaling["factor"]
+            original_max_position = rope_scaling[
+                "original_max_position_embeddings"]
+            extra_kwargs = {
+                k: v
+                for k, v in rope_scaling.items()
+                if k in ("extrapolation_factor", "attn_factor", "beta_fast",
+                         "beta_slow")
+            }
+            rotary_emb = InferYaRNScalingRotaryEmbedding(
+                head_size, rotary_dim, original_max_position, base,
+                is_neox_style, scaling_factor, dtype, **extra_kwargs)
         else:
             raise NotImplementedError
 
diff --git a/vllm_mindspore/model_executor/models/qwen2.py b/vllm_mindspore/model_executor/models/qwen2.py
index 87c54c2126c2456cf7ec73ec123f8b3050386570..3840ecbb5351f6447d2f51d39dfe0de48e0131d9 100644
--- a/vllm_mindspore/model_executor/models/qwen2.py
+++ b/vllm_mindspore/model_executor/models/qwen2.py
@@ -268,7 +268,11 @@ class Qwen2DecoderLayer(nn.Cell):
 
 class Qwen2Model(nn.Cell):
 
-    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
+    def __init__(self,
+                 *,
+                 vllm_config: VllmConfig,
+                 prefix: str = "",
+                 decoder_layer_type: type[nn.Cell] = Qwen2DecoderLayer):
         super().__init__()
         config = vllm_config.model_config.hf_config
 
@@ -292,10 +296,10 @@ class Qwen2Model(nn.Cell):
 
         self.start_layer, self.end_layer, self.layers = make_layers(
             config.num_hidden_layers,
-            lambda prefix: Qwen2DecoderLayer(config=config,
-                                             cache_config=cache_config,
-                                             quant_config=quant_config,
-                                             prefix=prefix),
+            lambda prefix: decoder_layer_type(config=config,
+                                              cache_config=cache_config,
+                                              quant_config=quant_config,
+                                              prefix=prefix),
             prefix=f"{prefix}.layers",
         )
 
diff --git a/vllm_mindspore/model_executor/models/qwen3.py b/vllm_mindspore/model_executor/models/qwen3.py
new file mode 100644
index 0000000000000000000000000000000000000000..b9200e2ff97f6a71709c8e9093daca083fb3497f
--- /dev/null
+++ b/vllm_mindspore/model_executor/models/qwen3.py
@@ -0,0 +1,320 @@
+#!/usr/bin/env python3
+# Copyright 2025 Huawei Technologies Co., Ltd
+# Copyright 2025 The vLLM team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""Inference-only Qwen3 model compatible with HuggingFace weights."""
+from collections.abc import Iterable
+from typing import Any, Dict, Optional, Union
+
+from mindspore import Tensor, nn
+from transformers import Qwen3Config
+from vllm.attention import AttentionType
+from vllm.config import CacheConfig, VllmConfig
+from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
+from vllm.logger import init_logger
+from vllm.model_executor.layers.quantization import QuantizationConfig
+from vllm.model_executor.sampling_metadata import SamplingMetadata
+from vllm.sequence import IntermediateTensors
+
+from vllm_mindspore.attention import Attention
+from vllm_mindspore.model_executor.layers.layernorm import RMSNorm
+from vllm_mindspore.model_executor.layers.linear import (QKVParallelLinear,
+                                                         RowParallelLinear)
+from vllm_mindspore.model_executor.layers.logits_processor import (
+    LogitsProcessor)
+from vllm_mindspore.model_executor.layers.sampler import (SamplerOutput,
+                                                          get_sampler)
+from vllm_mindspore.model_executor.layers.vocab_parallel_embedding import (
+    ParallelLMHead)
+from vllm_mindspore.model_executor.models.model_base import NativeModel
+from vllm_mindspore.model_executor.models.utils import (PPMissingLayer,
+                                                        maybe_prefix)
+
+from vllm_mindspore.model_executor.layers.rotary_embedding import (  # type: ignore[attr-defined] # isort: skip
+    get_rope)
+
+from vllm_mindspore.model_executor.models.qwen2 import (  # type: ignore[attr-defined] # isort: skip
+    Qwen2MLP as Qwen3MLP)
+from vllm_mindspore.model_executor.models.qwen2 import (  # type: ignore[attr-defined] # isort: skip
+    Qwen2Model)
+
+logger = init_logger(__name__)
+
+
+class Qwen3Attention(nn.Cell):
+
+    def __init__(self,
+                 hidden_size: int,
+                 num_heads: int,
+                 num_kv_heads: int,
+                 max_position: int = 4096 * 32,
+                 head_dim: Optional[int] = None,
+                 rms_norm_eps: float = 1e-06,
+                 qkv_bias: bool = False,
+                 rope_theta: float = 10000,
+                 cache_config: Optional[CacheConfig] = None,
+                 quant_config: Optional[QuantizationConfig] = None,
+                 rope_scaling: Optional[Dict[str, Any]] = None,
+                 prefix: str = "",
+                 attn_type: str = AttentionType.DECODER) -> None:
+        super().__init__()
+        self.hidden_size = hidden_size
+        tp_size = get_tensor_model_parallel_world_size()
+        self.total_num_heads = num_heads
+        assert self.total_num_heads % tp_size == 0
+        self.num_heads = self.total_num_heads // tp_size
+        self.total_num_kv_heads = num_kv_heads
+        if self.total_num_kv_heads >= tp_size:
+            # Number of KV heads is greater than TP size, so we partition
+            # the KV heads across multiple tensor parallel GPUs.
+            assert self.total_num_kv_heads % tp_size == 0
+        else:
+            # Number of KV heads is less than TP size, so we replicate
+            # the KV heads across multiple tensor parallel GPUs.
+            assert tp_size % self.total_num_kv_heads == 0
+        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
+        self.head_dim = head_dim or hidden_size // self.total_num_heads
+        self.q_size = self.num_heads * self.head_dim
+        self.kv_size = self.num_kv_heads * self.head_dim
+        self.scaling = self.head_dim**-0.5
+        self.rope_theta = rope_theta
+
+        self.qkv_proj = QKVParallelLinear(
+            hidden_size,
+            self.head_dim,
+            self.total_num_heads,
+            self.total_num_kv_heads,
+            bias=qkv_bias,
+            quant_config=quant_config,
+            prefix=f"{prefix}.qkv_proj",
+        )
+        self.o_proj = RowParallelLinear(
+            self.total_num_heads * self.head_dim,
+            hidden_size,
+            bias=False,
+            quant_config=quant_config,
+            prefix=f"{prefix}.o_proj",
+        )
+
+        self.rotary_emb = get_rope(
+            self.head_dim,
+            rotary_dim=self.head_dim,
+            max_position=max_position,
+            base=self.rope_theta,  # type: ignore[arg-type]
+            rope_scaling=rope_scaling,
+        )
+        self.attn = Attention(self.num_heads,
+                              self.head_dim,
+                              self.scaling,
+                              num_kv_heads=self.num_kv_heads,
+                              cache_config=cache_config,
+                              quant_config=quant_config,
+                              prefix=f"{prefix}.attn",
+                              attn_type=attn_type)
+        self.q_norm = RMSNorm(self.head_dim, eps=rms_norm_eps)
+        self.k_norm = RMSNorm(self.head_dim, eps=rms_norm_eps)
+
+    def construct(
+        self,
+        positions: Tensor,
+        hidden_states: Tensor,
+        key_cache: Tensor,
+        value_cache: Tensor,
+        is_prefill: bool,
+        slot_mapping: Tensor,
+        attn_mask: Tensor,
+        batch_valid_length: Tensor,
+        q_seq_lens: Tensor,
+        block_tables: Tensor,
+    ) -> Tensor:
+        qkv, _ = self.qkv_proj(hidden_states)
+        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
+        # Add qk-norm
+        q_by_head = q.view(*q.shape[:-1], q.shape[-1] // self.head_dim,
+                           self.head_dim)
+        q_by_head = self.q_norm(q_by_head)
+        q = q_by_head.view(q.shape)
+        k_by_head = k.view(*k.shape[:-1], k.shape[-1] // self.head_dim,
+                           self.head_dim)
+        k_by_head = self.k_norm(k_by_head)
+        k = k_by_head.view(k.shape)
+        q, k = self.rotary_emb(positions, q, k, batch_valid_length, is_prefill)
+        attn_output = self.attn(q, k, v, key_cache, value_cache, is_prefill,
+                                slot_mapping, attn_mask, batch_valid_length,
+                                q_seq_lens, block_tables)
+        output, _ = self.o_proj(attn_output)
+        return output
+
+
+class Qwen3DecoderLayer(nn.Cell):
+
+    def __init__(
+        self,
+        config: Qwen3Config,
+        cache_config: Optional[CacheConfig] = None,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ) -> None:
+        super().__init__()
+        self.hidden_size = config.hidden_size
+        # Requires transformers > 4.32.0
+        rope_theta = getattr(config, "rope_theta", 1000000)
+        rope_scaling = getattr(config, "rope_scaling", None)
+
+        # By default, Qwen3 uses causal attention as it is a decoder-only model.
+        # You can override the HF config with `is_causal=False` to enable
+        # bidirectional attention, which is used in some embedding models
+        # (e.g. Alibaba-NLP/gte-Qwen3-7B-instruct)
+        if getattr(config, "is_causal", True):
+            attn_type = AttentionType.DECODER
+        else:
+            attn_type = AttentionType.ENCODER_ONLY
+
+        self.self_attn = Qwen3Attention(
+            hidden_size=self.hidden_size,
+            num_heads=config.num_attention_heads,
+            max_position=config.max_position_embeddings,
+            num_kv_heads=config.num_key_value_heads,
+            rope_theta=rope_theta,
+            rms_norm_eps=config.rms_norm_eps,
+            qkv_bias=getattr(config, 'attention_bias', False),
+            head_dim=getattr(config, 'head_dim', None),
+            cache_config=cache_config,
+            quant_config=quant_config,
+            rope_scaling=rope_scaling,
+            prefix=f"{prefix}.self_attn",
+            attn_type=attn_type,
+        )
+        self.mlp = Qwen3MLP(
+            hidden_size=self.hidden_size,
+            intermediate_size=config.intermediate_size,
+            hidden_act=config.hidden_act,
+            quant_config=quant_config,
+            prefix=f"{prefix}.mlp",
+        )
+        self.input_layernorm = RMSNorm(config.hidden_size,
+                                       eps=config.rms_norm_eps)
+        self.post_attention_layernorm = RMSNorm(config.hidden_size,
+                                                eps=config.rms_norm_eps)
+
+    def construct(
+        self,
+        positions: Tensor,
+        hidden_states: Tensor,
+        key_cache: Tensor,
+        value_cache: Tensor,
+        is_prefill: bool,
+        slot_mapping: Tensor,
+        attn_mask: Tensor,
+        batch_valid_length: Tensor,
+        q_seq_lens: Tensor,
+        block_tables: Tensor,
+        residual: Optional[Tensor],
+    ) -> tuple[Tensor, Tensor]:
+        # Self Attention
+        if residual is None:
+            residual = hidden_states
+            hidden_states = self.input_layernorm(hidden_states)
+        else:
+            hidden_states, residual = self.input_layernorm(
+                hidden_states, residual)
+        hidden_states = self.self_attn(positions, hidden_states, key_cache,
+                                       value_cache, is_prefill, slot_mapping,
+                                       attn_mask, batch_valid_length,
+                                       q_seq_lens, block_tables)
+
+        # Fully Connected
+        hidden_states, residual = self.post_attention_layernorm(
+            hidden_states, residual)
+        hidden_states = self.mlp(hidden_states)
+        return hidden_states, residual
+
+
+class Qwen3Model(Qwen2Model):
+
+    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
+        super().__init__(vllm_config=vllm_config,
+                         prefix=prefix,
+                         decoder_layer_type=Qwen3DecoderLayer)
+
+
+class Qwen3ForCausalLM(NativeModel):
+
+    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
+        super().__init__(vllm_config=vllm_config, prefix=prefix)
+        config = vllm_config.model_config.hf_config
+        quant_config = vllm_config.quant_config
+        lora_config = vllm_config.lora_config
+
+        self.config = config
+        self.lora_config = lora_config
+
+        self.quant_config = quant_config
+        self.model = Qwen3Model(vllm_config=vllm_config,
+                                prefix=maybe_prefix(prefix, "model"))
+
+        if get_pp_group().is_last_rank:
+            if config.tie_word_embeddings:
+                self.lm_head = self.model.embed_tokens
+            else:
+                self.lm_head = ParallelLMHead(config.vocab_size,
+                                              config.hidden_size,
+                                              quant_config=quant_config,
+                                              prefix=maybe_prefix(
+                                                  prefix, "lm_head"))
+            self.sampler = get_sampler()
+        else:
+            self.lm_head = PPMissingLayer()
+
+        self.logits_processor = LogitsProcessor(config.vocab_size)
+
+        self.make_empty_intermediate_tensors = (
+            self.model.make_empty_intermediate_tensors)
+
+        self.common_preprocess(vllm_config, prefix)
+
+    def get_input_embeddings(self, input_ids: Tensor) -> Tensor:
+        return self.model.get_input_embeddings(input_ids)
+
+    def forward(self,
+                input_ids: Tensor,
+                positions: Tensor,
+                intermediate_tensors: Optional[IntermediateTensors] = None,
+                inputs_embeds: Optional[Tensor] = None,
+                **kwargs) -> Union[Tensor, IntermediateTensors]:
+        hidden_states = self.exec_model(input_ids, positions,
+                                        intermediate_tensors, inputs_embeds)
+        return hidden_states
+
+    def sample(self, logits: Tensor,
+               sampling_metadata: SamplingMetadata) -> Optional[SamplerOutput]:
+        next_tokens = self.sampler(logits, sampling_metadata)
+        return next_tokens
+
+    def compute_logits(
+        self,
+        hidden_states: Tensor,
+        sampling_metadata: SamplingMetadata,
+    ) -> Optional[Tensor]:
+        logits = self.logits_processor(self.lm_head, hidden_states,
+                                       sampling_metadata)
+        return logits
+
+    def load_weights(self, weights: Iterable[tuple[str, Tensor]]) -> set[str]:
+        params_dict = self.get_params_dict()
+        load_params = self.model.load_weights(weights, params_dict)
+        if self.config.tie_word_embeddings:
+            load_params.add("lm_head.weight")
+        return load_params
diff --git a/vllm_mindspore/model_executor/models/registry.py b/vllm_mindspore/model_executor/models/registry.py
index 009d84a06f124d270b88f26e98697ae4285cd7f5..6cdcbba76ebbe204c441f79857daf24af1e084ee 100644
--- a/vllm_mindspore/model_executor/models/registry.py
+++ b/vllm_mindspore/model_executor/models/registry.py
@@ -30,6 +30,7 @@ _NATIVE_MODELS = {
     "Qwen2ForCausalLM": ("qwen2", "Qwen2ForCausalLM"),
     "Qwen2_5_VLForConditionalGeneration":
     ("qwen2_5_vl", "Qwen2_5_VLForConditionalGeneration"),
+    "Qwen3ForCausalLM": ("qwen3", "Qwen3ForCausalLM"),
 }
 
 _MINDFORMERS_MODELS = {
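
Reference sketch (not part of the patch): a minimal, self-contained NumPy version of the YaRN inverse-frequency blend that InferYaRNScalingRotaryEmbedding._compute_inv_freq above implements. The rotary_dim/base/scaling values below are illustrative assumptions for demonstration only, not values taken from any Qwen3 checkpoint.

# Illustrative only; mirrors _compute_inv_freq above with assumed parameters.
import math

import numpy as np


def yarn_inv_freq(rotary_dim: int = 128,
                  base: float = 10000.0,
                  scaling_factor: float = 4.0,
                  beta_fast: int = 32,
                  beta_slow: int = 1,
                  max_position_embeddings: int = 2048,
                  extrapolation_factor: float = 1.0) -> np.ndarray:
    pos_freqs = base**(np.arange(0, rotary_dim, 2, dtype=np.float32) /
                       rotary_dim)
    inv_freq_extrapolation = 1.0 / pos_freqs
    inv_freq_interpolation = 1.0 / (scaling_factor * pos_freqs)

    def correction_dim(num_rotations: float) -> float:
        # Dim index whose rotary period covers `num_rotations` rotations
        # over the original context length.
        return (rotary_dim * math.log(max_position_embeddings /
                                      (num_rotations * 2 * math.pi))) / (
                                          2 * math.log(base))

    low = max(math.floor(correction_dim(beta_fast)), 0)
    high = min(math.ceil(correction_dim(beta_slow)), rotary_dim - 1)
    if low == high:
        high += 0.001  # avoid division by zero in the ramp
    ramp = np.clip(
        (np.arange(rotary_dim // 2, dtype=np.float32) - low) / (high - low),
        0, 1)
    inv_freq_mask = (1 - ramp) * extrapolation_factor
    # High-frequency dims keep the original (extrapolated) frequencies;
    # low-frequency dims are interpolated by `scaling_factor`.
    return (inv_freq_interpolation * (1 - inv_freq_mask) +
            inv_freq_extrapolation * inv_freq_mask)


if __name__ == "__main__":
    print(yarn_inv_freq()[:4])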