mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-03-23 07:45:48 +08:00
fix DummyConfig in tests
Signed-off-by: Xingyu Liu <charlotteliu12x@gmail.com>
This commit is contained in:
parent
1cf506d89e
commit
c327dffce1
@ -14,6 +14,9 @@ from vllm.config.model import ModelConfig, ModelDType, RunnerOption
|
||||
from vllm.logprobs import Logprob, PromptLogprobs, SampleLogprobs
|
||||
from vllm.multimodal.processing import InputProcessingContext
|
||||
from vllm.tokenizers import cached_tokenizer_from_config
|
||||
from vllm.transformers_utils.model_arch_config_convertor import (
|
||||
ModelArchConfigConvertorBase,
|
||||
)
|
||||
|
||||
from .. import ci_envs
|
||||
from .registry import HF_EXAMPLE_MODELS
|
||||
@ -488,7 +491,7 @@ def dummy_hf_overrides(
|
||||
|
||||
# Only set MoE related config when the model has MoE layers.
|
||||
# Otherwise all models are detected as MoE by _get_transformers_backend_cls.
|
||||
if ModelConfig.get_num_experts(DummyConfig) > 0:
|
||||
if ModelArchConfigConvertorBase.get_num_experts(text_config) > 0:
|
||||
update_dict.update(
|
||||
{
|
||||
"num_experts": num_experts,
|
||||
|
||||
@ -1,7 +1,7 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
from typing import TYPE_CHECKING, final
|
||||
|
||||
import torch
|
||||
from safetensors.torch import _TYPES as _SAFETENSORS_TO_TORCH_DTYPE
|
||||
@ -88,7 +88,9 @@ class ModelArchConfigConvertorBase:
|
||||
|
||||
return self.hf_text_config.num_attention_heads
|
||||
|
||||
def get_num_experts(self) -> int:
|
||||
@final
|
||||
@classmethod
|
||||
def get_num_experts(cls, hf_text_config: PretrainedConfig) -> int:
|
||||
"""Returns the number of experts in the model."""
|
||||
num_expert_names = [
|
||||
"num_experts", # Jamba
|
||||
@ -96,7 +98,7 @@ class ModelArchConfigConvertorBase:
|
||||
"n_routed_experts", # DeepSeek
|
||||
"num_local_experts", # Mixtral
|
||||
]
|
||||
num_experts = getattr_iter(self.hf_text_config, num_expert_names, 0)
|
||||
num_experts = getattr_iter(hf_text_config, num_expert_names, 0)
|
||||
if isinstance(num_experts, list):
|
||||
# Ernie VL's remote code uses list[int]...
|
||||
# The values are always the same so we just take the first one.
|
||||
@ -104,8 +106,11 @@ class ModelArchConfigConvertorBase:
|
||||
# Coerce to 0 if explicitly set to None
|
||||
return num_experts or 0
|
||||
|
||||
@final
|
||||
@classmethod
|
||||
def get_torch_dtype(cls, hf_config, model_id: str, revision: str | None):
|
||||
def get_torch_dtype(
|
||||
cls, hf_config: PretrainedConfig, model_id: str, revision: str | None
|
||||
):
|
||||
# NOTE: getattr(config, "dtype", torch.float32) is not correct
|
||||
# because config.dtype can be None.
|
||||
config_dtype = getattr(hf_config, "dtype", None)
|
||||
@ -139,6 +144,7 @@ class ModelArchConfigConvertorBase:
|
||||
|
||||
return config_dtype
|
||||
|
||||
@final
|
||||
@classmethod
|
||||
def _normalize_quantization_config(cls, config: PretrainedConfig):
|
||||
quant_cfg = getattr(config, "quantization_config", None)
|
||||
@ -171,6 +177,7 @@ class ModelArchConfigConvertorBase:
|
||||
|
||||
return quant_cfg
|
||||
|
||||
@final
|
||||
@classmethod
|
||||
def get_quantization_config(cls, hf_config: PretrainedConfig):
|
||||
quant_cfg = cls._normalize_quantization_config(hf_config)
|
||||
@ -258,7 +265,7 @@ class ModelArchConfigConvertorBase:
|
||||
head_size=self.get_head_size(),
|
||||
vocab_size=self.get_vocab_size(),
|
||||
total_num_kv_heads=self.get_total_num_kv_heads(),
|
||||
num_experts=self.get_num_experts(),
|
||||
num_experts=self.get_num_experts(self.hf_text_config),
|
||||
quantization_config=self.get_quantization_config(self.hf_config),
|
||||
torch_dtype=self.get_torch_dtype(self.hf_config, model_id, revision),
|
||||
is_multimodal_model=self.is_multimodal_model(),
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user