From 8e9f93ed33e51977b10d0dff23ea9fb0e0639e25 Mon Sep 17 00:00:00 2001 From: liningl Date: Mon, 21 Jul 2025 14:31:56 +0800 Subject: [PATCH] add quantization config --- vllm_mindspore/__init__.py | 16 + vllm_mindspore/engine/arg_utils.py | 427 +++++++++++++++++- .../layers/quantization/__init__.py | 49 ++ .../layers/quantization/base_config.py | 4 +- .../model_loader/weight_utils.py | 107 ++++- 5 files changed, 598 insertions(+), 5 deletions(-) diff --git a/vllm_mindspore/__init__.py b/vllm_mindspore/__init__.py index a4e049379..fd79f9223 100644 --- a/vllm_mindspore/__init__.py +++ b/vllm_mindspore/__init__.py @@ -475,4 +475,20 @@ from vllm_mindspore.entrypoints.__main__ import ( patch_server_run_api_server_worker_proc() +from vllm_mindspore.model_executor.layers.quantization import ( + get_quantization_config) + +vllm.model_executor.layers.quantization.get_quantization_config \ + = get_quantization_config +vllm.config.get_quantization_config = get_quantization_config +vllm.model_executor.model_loader.weight_utils.get_quantization_config \ + = get_quantization_config + +from vllm_mindspore.model_executor.model_loader.weight_utils import ( + get_quant_config) + +vllm.model_executor.model_loader.weight_utils.get_quant_config \ + = get_quant_config +vllm.config.get_quant_config = get_quant_config + check_ready() diff --git a/vllm_mindspore/engine/arg_utils.py b/vllm_mindspore/engine/arg_utils.py index a7bfc2204..400bca4e1 100644 --- a/vllm_mindspore/engine/arg_utils.py +++ b/vllm_mindspore/engine/arg_utils.py @@ -19,15 +19,26 @@ # limitations under the License. """Adaption for arguments utils.""" +import argparse +import json +import sys import threading +from itertools import permutations from typing import get_args import torch import vllm.envs as envs -from vllm.config import (GuidedDecodingBackendV1, LoadFormat, ModelConfig, - ParallelConfig, SchedulerConfig) +from vllm.config import (CacheConfig, ConfigFormat, DecodingConfig, + DetailedTraceModules, DeviceConfig, + GuidedDecodingBackendV1, LoadConfig, LoadFormat, + LoRAConfig, ModelConfig, ModelImpl, MultiModalConfig, + ObservabilityConfig, ParallelConfig, + PromptAdapterConfig, SchedulerConfig, + SpeculativeConfig, TokenizerPoolConfig, VllmConfig) from vllm.engine.arg_utils import (EngineArgs, _raise_or_fallback, - _warn_or_fallback) + _warn_or_fallback, get_kwargs) +from vllm.reasoning import ReasoningParserManager +from vllm.utils import FlexibleArgumentParser def _is_v1_supported_oracle(self, model_config: ModelConfig) -> bool: @@ -224,3 +235,413 @@ def _is_v1_supported_oracle(self, model_config: ModelConfig) -> bool: ############################################################# return True + + +def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: + """Shared CLI arguments for vLLM engine.""" + + # Model arguments + model_kwargs = get_kwargs(ModelConfig) + model_group = parser.add_argument_group( + title="ModelConfig", + description=ModelConfig.__doc__, + ) + if not ('serve' in sys.argv[1:] and '--help' in sys.argv[1:]): + model_group.add_argument("--model", **model_kwargs["model"]) + model_group.add_argument("--task", **model_kwargs["task"]) + model_group.add_argument("--tokenizer", **model_kwargs["tokenizer"]) + model_group.add_argument("--tokenizer-mode", + **model_kwargs["tokenizer_mode"]) + model_group.add_argument("--trust-remote-code", + **model_kwargs["trust_remote_code"]) + model_group.add_argument("--dtype", **model_kwargs["dtype"]) + model_group.add_argument("--seed", **model_kwargs["seed"]) + 
model_group.add_argument("--hf-config-path", + **model_kwargs["hf_config_path"]) + model_group.add_argument("--allowed-local-media-path", + **model_kwargs["allowed_local_media_path"]) + model_group.add_argument("--revision", **model_kwargs["revision"]) + model_group.add_argument("--code-revision", + **model_kwargs["code_revision"]) + model_group.add_argument("--rope-scaling", **model_kwargs["rope_scaling"]) + model_group.add_argument("--rope-theta", **model_kwargs["rope_theta"]) + model_group.add_argument("--tokenizer-revision", + **model_kwargs["tokenizer_revision"]) + model_group.add_argument("--max-model-len", + **model_kwargs["max_model_len"]) + model_group.add_argument("--quantization", "-q", + **model_kwargs["quantization"]) + model_group.add_argument("--enforce-eager", + **model_kwargs["enforce_eager"]) + model_group.add_argument("--max-seq-len-to-capture", + **model_kwargs["max_seq_len_to_capture"]) + model_group.add_argument("--max-logprobs", **model_kwargs["max_logprobs"]) + model_group.add_argument("--disable-sliding-window", + **model_kwargs["disable_sliding_window"]) + model_group.add_argument("--disable-cascade-attn", + **model_kwargs["disable_cascade_attn"]) + model_group.add_argument("--skip-tokenizer-init", + **model_kwargs["skip_tokenizer_init"]) + model_group.add_argument("--enable-prompt-embeds", + **model_kwargs["enable_prompt_embeds"]) + model_group.add_argument("--served-model-name", + **model_kwargs["served_model_name"]) + # This one is a special case because it is the + # opposite of ModelConfig.use_async_output_proc + model_group.add_argument( + "--disable-async-output-proc", + action="store_true", + default=EngineArgs.disable_async_output_proc, + help="Disable async output processing. This may result in " + "lower performance.") + model_group.add_argument("--config-format", + choices=[f.value for f in ConfigFormat], + **model_kwargs["config_format"]) + # This one is a special case because it can bool + # or str. 
TODO: Handle this in get_kwargs + model_group.add_argument("--hf-token", + type=str, + nargs="?", + const=True, + default=model_kwargs["hf_token"]["default"], + help=model_kwargs["hf_token"]["help"]) + model_group.add_argument("--hf-overrides", **model_kwargs["hf_overrides"]) + model_group.add_argument("--override-neuron-config", + **model_kwargs["override_neuron_config"]) + model_group.add_argument("--override-pooler-config", + **model_kwargs["override_pooler_config"]) + model_group.add_argument("--logits-processor-pattern", + **model_kwargs["logits_processor_pattern"]) + model_group.add_argument("--generation-config", + **model_kwargs["generation_config"]) + model_group.add_argument("--override-generation-config", + **model_kwargs["override_generation_config"]) + model_group.add_argument("--enable-sleep-mode", + **model_kwargs["enable_sleep_mode"]) + model_group.add_argument("--model-impl", + choices=[f.value for f in ModelImpl], + **model_kwargs["model_impl"]) + + # Model loading arguments + load_kwargs = get_kwargs(LoadConfig) + load_group = parser.add_argument_group( + title="LoadConfig", + description=LoadConfig.__doc__, + ) + load_group.add_argument("--load-format", + choices=[f.value for f in LoadFormat], + **load_kwargs["load_format"]) + load_group.add_argument("--download-dir", **load_kwargs["download_dir"]) + load_group.add_argument("--model-loader-extra-config", + **load_kwargs["model_loader_extra_config"]) + load_group.add_argument("--ignore-patterns", + **load_kwargs["ignore_patterns"]) + load_group.add_argument("--use-tqdm-on-load", + **load_kwargs["use_tqdm_on_load"]) + load_group.add_argument( + "--qlora-adapter-name-or-path", + type=str, + default=None, + help="The `--qlora-adapter-name-or-path` has no effect, do not set" + " it, and it will be removed in v0.10.0.", + deprecated=True, + ) + load_group.add_argument('--pt-load-map-location', + **load_kwargs["pt_load_map_location"]) + + # Guided decoding arguments + guided_decoding_kwargs = get_kwargs(DecodingConfig) + guided_decoding_group = parser.add_argument_group( + title="DecodingConfig", + description=DecodingConfig.__doc__, + ) + guided_decoding_group.add_argument("--guided-decoding-backend", + **guided_decoding_kwargs["backend"]) + guided_decoding_group.add_argument( + "--guided-decoding-disable-fallback", + **guided_decoding_kwargs["disable_fallback"]) + guided_decoding_group.add_argument( + "--guided-decoding-disable-any-whitespace", + **guided_decoding_kwargs["disable_any_whitespace"]) + guided_decoding_group.add_argument( + "--guided-decoding-disable-additional-properties", + **guided_decoding_kwargs["disable_additional_properties"]) + guided_decoding_group.add_argument( + "--enable-reasoning", + action=argparse.BooleanOptionalAction, + deprecated=True, + help="[DEPRECATED] The `--enable-reasoning` flag is deprecated as " + "of v0.9.0. Use `--reasoning-parser` to specify the reasoning " + "parser backend instead. This flag (`--enable-reasoning`) will be " + "removed in v0.10.0. 
When `--reasoning-parser` is specified, " + "reasoning mode is automatically enabled.") + guided_decoding_group.add_argument( + "--reasoning-parser", + # This choices is a special case because it's not static + choices=list(ReasoningParserManager.reasoning_parsers), + **guided_decoding_kwargs["reasoning_backend"]) + + # Parallel arguments + parallel_kwargs = get_kwargs(ParallelConfig) + parallel_group = parser.add_argument_group( + title="ParallelConfig", + description=ParallelConfig.__doc__, + ) + parallel_group.add_argument( + "--distributed-executor-backend", + **parallel_kwargs["distributed_executor_backend"]) + parallel_group.add_argument("--pipeline-parallel-size", "-pp", + **parallel_kwargs["pipeline_parallel_size"]) + parallel_group.add_argument("--tensor-parallel-size", "-tp", + **parallel_kwargs["tensor_parallel_size"]) + parallel_group.add_argument("--data-parallel-size", "-dp", + **parallel_kwargs["data_parallel_size"]) + parallel_group.add_argument('--data-parallel-size-local', + '-dpl', + type=int, + help='Number of data parallel replicas ' + 'to run on this node.') + parallel_group.add_argument('--data-parallel-address', + '-dpa', + type=str, + help='Address of data parallel cluster ' + 'head-node.') + parallel_group.add_argument('--data-parallel-rpc-port', + '-dpp', + type=int, + help='Port for data parallel RPC ' + 'communication.') + parallel_group.add_argument('--data-parallel-backend', + '-dpb', + type=str, + default='mp', + help='Backend for data parallel, either ' + '"mp" or "ray".') + parallel_group.add_argument("--enable-expert-parallel", + **parallel_kwargs["enable_expert_parallel"]) + parallel_group.add_argument( + "--max-parallel-loading-workers", + **parallel_kwargs["max_parallel_loading_workers"]) + parallel_group.add_argument("--ray-workers-use-nsight", + **parallel_kwargs["ray_workers_use_nsight"]) + parallel_group.add_argument("--disable-custom-all-reduce", + **parallel_kwargs["disable_custom_all_reduce"]) + parallel_group.add_argument("--worker-cls", + **parallel_kwargs["worker_cls"]) + parallel_group.add_argument("--worker-extension-cls", + **parallel_kwargs["worker_extension_cls"]) + parallel_group.add_argument( + "--enable-multimodal-encoder-data-parallel", + **parallel_kwargs["enable_multimodal_encoder_data_parallel"]) + + # KV cache arguments + cache_kwargs = get_kwargs(CacheConfig) + cache_group = parser.add_argument_group( + title="CacheConfig", + description=CacheConfig.__doc__, + ) + cache_group.add_argument("--block-size", **cache_kwargs["block_size"]) + cache_group.add_argument("--gpu-memory-utilization", + **cache_kwargs["gpu_memory_utilization"]) + cache_group.add_argument("--swap-space", **cache_kwargs["swap_space"]) + cache_group.add_argument("--kv-cache-dtype", **cache_kwargs["cache_dtype"]) + cache_group.add_argument("--num-gpu-blocks-override", + **cache_kwargs["num_gpu_blocks_override"]) + cache_group.add_argument("--enable-prefix-caching", + **cache_kwargs["enable_prefix_caching"]) + cache_group.add_argument("--prefix-caching-hash-algo", + **cache_kwargs["prefix_caching_hash_algo"]) + cache_group.add_argument("--cpu-offload-gb", + **cache_kwargs["cpu_offload_gb"]) + cache_group.add_argument("--calculate-kv-scales", + **cache_kwargs["calculate_kv_scales"]) + + # Tokenizer arguments + tokenizer_kwargs = get_kwargs(TokenizerPoolConfig) + tokenizer_group = parser.add_argument_group( + title="TokenizerPoolConfig", + description=TokenizerPoolConfig.__doc__, + ) + tokenizer_group.add_argument("--tokenizer-pool-size", + 
**tokenizer_kwargs["pool_size"]) + tokenizer_group.add_argument("--tokenizer-pool-type", + **tokenizer_kwargs["pool_type"]) + tokenizer_group.add_argument("--tokenizer-pool-extra-config", + **tokenizer_kwargs["extra_config"]) + + # Multimodal related configs + multimodal_kwargs = get_kwargs(MultiModalConfig) + multimodal_group = parser.add_argument_group( + title="MultiModalConfig", + description=MultiModalConfig.__doc__, + ) + multimodal_group.add_argument("--limit-mm-per-prompt", + **multimodal_kwargs["limit_per_prompt"]) + multimodal_group.add_argument("--mm-processor-kwargs", + **multimodal_kwargs["mm_processor_kwargs"]) + multimodal_group.add_argument( + "--disable-mm-preprocessor-cache", + **multimodal_kwargs["disable_mm_preprocessor_cache"]) + + # LoRA related configs + lora_kwargs = get_kwargs(LoRAConfig) + lora_group = parser.add_argument_group( + title="LoRAConfig", + description=LoRAConfig.__doc__, + ) + lora_group.add_argument("--enable-lora", + action=argparse.BooleanOptionalAction, + help="If True, enable handling of LoRA adapters.") + lora_group.add_argument("--enable-lora-bias", + **lora_kwargs["bias_enabled"]) + lora_group.add_argument("--max-loras", **lora_kwargs["max_loras"]) + lora_group.add_argument("--max-lora-rank", **lora_kwargs["max_lora_rank"]) + lora_group.add_argument("--lora-extra-vocab-size", + **lora_kwargs["lora_extra_vocab_size"]) + lora_group.add_argument( + "--lora-dtype", + **lora_kwargs["lora_dtype"], + ) + lora_group.add_argument("--long-lora-scaling-factors", + **lora_kwargs["long_lora_scaling_factors"]) + lora_group.add_argument("--max-cpu-loras", **lora_kwargs["max_cpu_loras"]) + lora_group.add_argument("--fully-sharded-loras", + **lora_kwargs["fully_sharded_loras"]) + + # PromptAdapter related configs + prompt_adapter_kwargs = get_kwargs(PromptAdapterConfig) + prompt_adapter_group = parser.add_argument_group( + title="PromptAdapterConfig", + description=PromptAdapterConfig.__doc__, + ) + prompt_adapter_group.add_argument( + "--enable-prompt-adapter", + action=argparse.BooleanOptionalAction, + help="If True, enable handling of PromptAdapters.") + prompt_adapter_group.add_argument( + "--max-prompt-adapters", + **prompt_adapter_kwargs["max_prompt_adapters"]) + prompt_adapter_group.add_argument( + "--max-prompt-adapter-token", + **prompt_adapter_kwargs["max_prompt_adapter_token"]) + + # Device arguments + device_kwargs = get_kwargs(DeviceConfig) + device_group = parser.add_argument_group( + title="DeviceConfig", + description=DeviceConfig.__doc__, + ) + device_group.add_argument("--device", + **device_kwargs["device"], + deprecated=True) + + # Speculative arguments + speculative_group = parser.add_argument_group( + title="SpeculativeConfig", + description=SpeculativeConfig.__doc__, + ) + speculative_group.add_argument( + "--speculative-config", + type=json.loads, + default=None, + help="The configurations for speculative decoding. 
Should be a " + "JSON string.") + + # Observability arguments + observability_kwargs = get_kwargs(ObservabilityConfig) + observability_group = parser.add_argument_group( + title="ObservabilityConfig", + description=ObservabilityConfig.__doc__, + ) + observability_group.add_argument( + "--show-hidden-metrics-for-version", + **observability_kwargs["show_hidden_metrics_for_version"]) + observability_group.add_argument( + "--otlp-traces-endpoint", + **observability_kwargs["otlp_traces_endpoint"]) + # TODO: generalise this special case + choices = observability_kwargs["collect_detailed_traces"]["choices"] + metavar = f"{{{','.join(choices)}}}" + observability_kwargs["collect_detailed_traces"]["metavar"] = metavar + observability_kwargs["collect_detailed_traces"]["choices"] += [ + ",".join(p) for p in permutations(get_args(DetailedTraceModules), r=2) + ] + observability_group.add_argument( + "--collect-detailed-traces", + **observability_kwargs["collect_detailed_traces"]) + + # Scheduler arguments + scheduler_kwargs = get_kwargs(SchedulerConfig) + scheduler_group = parser.add_argument_group( + title="SchedulerConfig", + description=SchedulerConfig.__doc__, + ) + scheduler_group.add_argument("--max-num-batched-tokens", + **scheduler_kwargs["max_num_batched_tokens"]) + scheduler_group.add_argument("--max-num-seqs", + **scheduler_kwargs["max_num_seqs"]) + scheduler_group.add_argument( + "--max-num-partial-prefills", + **scheduler_kwargs["max_num_partial_prefills"]) + scheduler_group.add_argument( + "--max-long-partial-prefills", + **scheduler_kwargs["max_long_partial_prefills"]) + scheduler_group.add_argument('--cuda-graph-sizes', + **scheduler_kwargs["cuda_graph_sizes"]) + scheduler_group.add_argument( + "--long-prefill-token-threshold", + **scheduler_kwargs["long_prefill_token_threshold"]) + scheduler_group.add_argument("--num-lookahead-slots", + **scheduler_kwargs["num_lookahead_slots"]) + scheduler_group.add_argument("--scheduler-delay-factor", + **scheduler_kwargs["delay_factor"]) + scheduler_group.add_argument("--preemption-mode", + **scheduler_kwargs["preemption_mode"]) + scheduler_group.add_argument("--num-scheduler-steps", + **scheduler_kwargs["num_scheduler_steps"]) + scheduler_group.add_argument( + "--multi-step-stream-outputs", + **scheduler_kwargs["multi_step_stream_outputs"]) + scheduler_group.add_argument("--scheduling-policy", + **scheduler_kwargs["policy"]) + scheduler_group.add_argument("--enable-chunked-prefill", + **scheduler_kwargs["enable_chunked_prefill"]) + scheduler_group.add_argument( + "--disable-chunked-mm-input", + **scheduler_kwargs["disable_chunked_mm_input"]) + scheduler_group.add_argument("--scheduler-cls", + **scheduler_kwargs["scheduler_cls"]) + scheduler_group.add_argument( + "--disable-hybrid-kv-cache-manager", + **scheduler_kwargs["disable_hybrid_kv_cache_manager"]) + + # vLLM arguments + vllm_kwargs = get_kwargs(VllmConfig) + vllm_group = parser.add_argument_group( + title="VllmConfig", + description=VllmConfig.__doc__, + ) + vllm_group.add_argument("--kv-transfer-config", + **vllm_kwargs["kv_transfer_config"]) + vllm_group.add_argument('--kv-events-config', + **vllm_kwargs["kv_events_config"]) + vllm_group.add_argument("--compilation-config", "-O", + **vllm_kwargs["compilation_config"]) + vllm_group.add_argument("--additional-config", + **vllm_kwargs["additional_config"]) + + # Other arguments + parser.add_argument('--use-v2-block-manager', + action='store_true', + default=True, + deprecated=True, + help='[DEPRECATED] block manager v1 has been ' + 'removed 
and SelfAttnBlockSpaceManager (i.e. ' + 'block manager v2) is now the default. ' + 'Setting this flag to True or False' + ' has no effect on vLLM behavior.') + parser.add_argument('--disable-log-stats', + action='store_true', + help='Disable logging statistics.') + + return parser diff --git a/vllm_mindspore/model_executor/layers/quantization/__init__.py b/vllm_mindspore/model_executor/layers/quantization/__init__.py index e69de29bb..b7e3556ba 100644 --- a/vllm_mindspore/model_executor/layers/quantization/__init__.py +++ b/vllm_mindspore/model_executor/layers/quantization/__init__.py @@ -0,0 +1,49 @@ +# SPDX-License-Identifier: Apache-2.0 + +# Adapted from +# https://github.com/vllm-project/vllm/blob/v0.8.3/vllm/model_executor/layers/quantization/__init__.py +# +# Copyright 2025 Huawei Technologies Co., Ltd. +# Copyright 2024-2025 The vLLM team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig) + +QUANTIZATION_METHODS: list[str] = ["ModelSlim"] + +# The customized quantization methods which will be added to this dict. +_CUSTOMIZED_METHOD_TO_QUANT_CONFIG: dict[str, type[QuantizationConfig]] = {} + + +def get_quantization_config(quantization: str) -> type[QuantizationConfig]: + if quantization not in QUANTIZATION_METHODS: + raise ValueError(f"Invalid quantization method: {quantization}") + + from vllm_mindspore.model_executor.layers.quantization.smooth_quant_modelslim import ( # noqa:E501 + SmoothQuantModelSlimConfig) + method_to_config: dict[str, type[QuantizationConfig]] = { + "ModelSlim": SmoothQuantModelSlimConfig + } + # Update the `method_to_config` with customized quantization methods. 
+ method_to_config.update(_CUSTOMIZED_METHOD_TO_QUANT_CONFIG) + + return method_to_config[quantization] + + +__all__ = [ + "QuantizationConfig", + "get_quantization_config", + "QUANTIZATION_METHODS", +] diff --git a/vllm_mindspore/model_executor/layers/quantization/base_config.py b/vllm_mindspore/model_executor/layers/quantization/base_config.py index 37144a431..7834d1392 100644 --- a/vllm_mindspore/model_executor/layers/quantization/base_config.py +++ b/vllm_mindspore/model_executor/layers/quantization/base_config.py @@ -23,6 +23,8 @@ from abc import ABC, abstractmethod from typing import Any, Optional import mindspore as ms +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig as BaseQuantizationConfig) class QuantizeMethodBase(ABC): @@ -58,7 +60,7 @@ class QuantizeMethodBase(ABC): return -class QuantizationConfig(ABC): +class QuantizationConfig(BaseQuantizationConfig): """Base class for quantization configs.""" def __init__(self): diff --git a/vllm_mindspore/model_executor/model_loader/weight_utils.py b/vllm_mindspore/model_executor/model_loader/weight_utils.py index 6bf2dd4cd..bebd4d874 100644 --- a/vllm_mindspore/model_executor/model_loader/weight_utils.py +++ b/vllm_mindspore/model_executor/model_loader/weight_utils.py @@ -18,15 +18,29 @@ # See the License for the specific language governing permissions and # limitations under the License. +import glob +import json +import os from collections.abc import Generator from typing import Any +import huggingface_hub import mindspore as ms +from huggingface_hub import snapshot_download from mindspore import Parameter from safetensors import safe_open from tqdm.auto import tqdm +from vllm.config import LoadConfig from vllm.model_executor.model_loader.weight_utils import (_BAR_FORMAT, - enable_tqdm) + DisabledTqdm, + enable_tqdm, + get_lock) + +from vllm_mindspore.model_executor.layers.quantization import ( + get_quantization_config) +from vllm_mindspore.model_executor.layers.quantization.base_config import ( + QuantizationConfig) +from vllm_mindspore.platforms.ascend import ModelConfig def split_loaded_weight(loaded_weight, shard_dim, start_idx, shard_size): @@ -78,3 +92,94 @@ def default_weight_loader(param: Parameter, loaded_weight: Any) -> None: """Default weight loader.""" loaded_weight = loaded_weight[:] param.set_data(ms.Tensor(loaded_weight, dtype=param.dtype)) + + +def get_quant_config(model_config: ModelConfig, + load_config: LoadConfig) -> QuantizationConfig: + quant_cls = get_quantization_config(model_config.quantization) + + # GGUF doesn't have config file + if model_config.quantization == "gguf": + return quant_cls.from_config({}) + + # Read the quantization config from the HF model config, if available. 
+ hf_quant_config = getattr(model_config.hf_config, "quantization_config", + None) + # some vision model may keep quantization_config in their text_config + hf_text_config = getattr(model_config.hf_config, "text_config", None) + if hf_quant_config is None and hf_text_config is not None: + hf_quant_config = getattr(hf_text_config, "quantization_config", None) + if hf_quant_config is None: + # compressed-tensors uses a compressions_config + hf_quant_config = getattr(model_config.hf_config, "compression_config", + None) + if hf_quant_config is not None: + if os.path.isdir(model_config.model): + quant_config_file = os.path.join( + model_config.model, + quant_cls.get_config_filenames()[0]) + with open(quant_config_file) as f: + quant_config = json.load(f) + return quant_cls.from_config(hf_quant_config | quant_config) + + # In case of bitsandbytes/QLoRA, get quant config from the adapter model. + if model_config.quantization == "bitsandbytes": + if (not load_config.model_loader_extra_config + or "qlora_adapter_name_or_path" + not in load_config.model_loader_extra_config): + return quant_cls.from_config({"adapter_name_or_path": ""}) + model_name_or_path = load_config.model_loader_extra_config[ + "qlora_adapter_name_or_path"] + + else: + model_name_or_path = model_config.model + is_local = os.path.isdir(model_name_or_path) + if not is_local: + # Download the config files. + with get_lock(model_name_or_path, load_config.download_dir): + hf_folder = snapshot_download( + model_name_or_path, + revision=model_config.revision, + allow_patterns="*.json", + cache_dir=load_config.download_dir, + local_files_only=huggingface_hub.constants.HF_HUB_OFFLINE, + tqdm_class=DisabledTqdm, + ) + else: + hf_folder = model_name_or_path + + possible_config_filenames = quant_cls.get_config_filenames() + + # If the quantization config is not found, use the default config. + if not possible_config_filenames: + return quant_cls() + + config_files = glob.glob(os.path.join(hf_folder, "*.json")) + + quant_config_files = [ + f for f in config_files if any( + f.endswith(x) for x in possible_config_filenames) + ] + if len(quant_config_files) == 0: + raise ValueError( + f"Cannot find the config file for {model_config.quantization}") + if len(quant_config_files) > 1: + raise ValueError( + f"Found multiple config files for {model_config.quantization}: " + f"{quant_config_files}") + + quant_config_file = quant_config_files[0] + with open(quant_config_file) as f: + config = json.load(f) + + if model_config.quantization == "bitsandbytes": + config["adapter_name_or_path"] = model_name_or_path + elif model_config.quantization == "modelopt": + if config["producer"]["name"] == "modelopt": + return quant_cls.from_config(config) + else: + raise ValueError( + f"Unsupported quantization config" + f" found for {model_config.quantization} in {f}.") + + return quant_cls.from_config(config) -- Gitee
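
A minimal usage sketch (not part of the patch, and hedged accordingly): it shows how the registry added in vllm_mindspore/model_executor/layers/quantization/__init__.py behaves once the monkey patches in vllm_mindspore/__init__.py are in place. It assumes the SmoothQuantModelSlimConfig module referenced by the patch is importable in your build; the names "MyQuant" and "MyCustomConfig" are hypothetical and only illustrate the _CUSTOMIZED_METHOD_TO_QUANT_CONFIG hook, since "ModelSlim" is the only method this patch registers.

    # Sketch only: exercises the quantization registry added by this patch.
    # Importing vllm_mindspore applies its patches (and runs check_ready()),
    # so this is expected to run only in a properly set up environment.
    from vllm_mindspore.model_executor.layers import quantization as msq

    # Built-in lookup: "ModelSlim" maps to SmoothQuantModelSlimConfig.
    print(msq.get_quantization_config("ModelSlim").__name__)

    # Any name outside QUANTIZATION_METHODS raises ValueError, matching the
    # check in get_quantization_config().
    try:
        msq.get_quantization_config("awq")
    except ValueError as err:
        print(err)  # Invalid quantization method: awq

    # Hypothetical out-of-tree registration via the customization hook:
    # msq.QUANTIZATION_METHODS.append("MyQuant")
    # msq._CUSTOMIZED_METHOD_TO_QUANT_CONFIG["MyQuant"] = MyCustomConfig
    # after which msq.get_quantization_config("MyQuant") returns MyCustomConfig.

Because vllm.config.get_quantization_config, vllm.model_executor.layers.quantization.get_quantization_config, and vllm.model_executor.model_loader.weight_utils.get_quant_config are all rebound in vllm_mindspore/__init__.py, vLLM's own config and weight-loading paths pick up this same registry without further changes.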