mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-05-31 12:37:09 +08:00
fix DummyConfig in tests
Signed-off-by: Xingyu Liu <charlotteliu12x@gmail.com>
This commit is contained in:
parent
1cf506d89e
commit
c327dffce1
@ -14,6 +14,9 @@ from vllm.config.model import ModelConfig, ModelDType, RunnerOption
|
|||||||
from vllm.logprobs import Logprob, PromptLogprobs, SampleLogprobs
|
from vllm.logprobs import Logprob, PromptLogprobs, SampleLogprobs
|
||||||
from vllm.multimodal.processing import InputProcessingContext
|
from vllm.multimodal.processing import InputProcessingContext
|
||||||
from vllm.tokenizers import cached_tokenizer_from_config
|
from vllm.tokenizers import cached_tokenizer_from_config
|
||||||
|
from vllm.transformers_utils.model_arch_config_convertor import (
|
||||||
|
ModelArchConfigConvertorBase,
|
||||||
|
)
|
||||||
|
|
||||||
from .. import ci_envs
|
from .. import ci_envs
|
||||||
from .registry import HF_EXAMPLE_MODELS
|
from .registry import HF_EXAMPLE_MODELS
|
||||||
@ -488,7 +491,7 @@ def dummy_hf_overrides(
|
|||||||
|
|
||||||
# Only set MoE-related config when the model has MoE layers.
|
# Only set MoE-related config when the model has MoE layers.
|
||||||
# Otherwise all models would be detected as MoE by _get_transformers_backend_cls.
|
# Otherwise all models would be detected as MoE by _get_transformers_backend_cls.
|
||||||
if ModelConfig.get_num_experts(DummyConfig) > 0:
|
if ModelArchConfigConvertorBase.get_num_experts(text_config) > 0:
|
||||||
update_dict.update(
|
update_dict.update(
|
||||||
{
|
{
|
||||||
"num_experts": num_experts,
|
"num_experts": num_experts,
|
||||||
|
|||||||
@ -1,7 +1,7 @@
|
|||||||
# SPDX-License-Identifier: Apache-2.0
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
|
||||||
from typing import TYPE_CHECKING
|
from typing import TYPE_CHECKING, final
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from safetensors.torch import _TYPES as _SAFETENSORS_TO_TORCH_DTYPE
|
from safetensors.torch import _TYPES as _SAFETENSORS_TO_TORCH_DTYPE
|
||||||
@ -88,7 +88,9 @@ class ModelArchConfigConvertorBase:
|
|||||||
|
|
||||||
return self.hf_text_config.num_attention_heads
|
return self.hf_text_config.num_attention_heads
|
||||||
|
|
||||||
def get_num_experts(self) -> int:
|
@final
|
||||||
|
@classmethod
|
||||||
|
def get_num_experts(cls, hf_text_config: PretrainedConfig) -> int:
|
||||||
"""Returns the number of experts in the model."""
|
"""Returns the number of experts in the model."""
|
||||||
num_expert_names = [
|
num_expert_names = [
|
||||||
"num_experts", # Jamba
|
"num_experts", # Jamba
|
||||||
@ -96,7 +98,7 @@ class ModelArchConfigConvertorBase:
|
|||||||
"n_routed_experts", # DeepSeek
|
"n_routed_experts", # DeepSeek
|
||||||
"num_local_experts", # Mixtral
|
"num_local_experts", # Mixtral
|
||||||
]
|
]
|
||||||
num_experts = getattr_iter(self.hf_text_config, num_expert_names, 0)
|
num_experts = getattr_iter(hf_text_config, num_expert_names, 0)
|
||||||
if isinstance(num_experts, list):
|
if isinstance(num_experts, list):
|
||||||
# Ernie VL's remote code uses list[int]...
|
# Ernie VL's remote code uses list[int]...
|
||||||
# The values are always the same so we just take the first one.
|
# The values are always the same so we just take the first one.
|
||||||
@ -104,8 +106,11 @@ class ModelArchConfigConvertorBase:
|
|||||||
# Coerce to 0 if explicitly set to None
|
# Coerce to 0 if explicitly set to None
|
||||||
return num_experts or 0
|
return num_experts or 0
|
||||||
|
|
||||||
|
@final
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_torch_dtype(cls, hf_config, model_id: str, revision: str | None):
|
def get_torch_dtype(
|
||||||
|
cls, hf_config: PretrainedConfig, model_id: str, revision: str | None
|
||||||
|
):
|
||||||
# NOTE: getattr(config, "dtype", torch.float32) is not correct
|
# NOTE: getattr(config, "dtype", torch.float32) is not correct
|
||||||
# because config.dtype can be None.
|
# because config.dtype can be None.
|
||||||
config_dtype = getattr(hf_config, "dtype", None)
|
config_dtype = getattr(hf_config, "dtype", None)
|
||||||
@ -139,6 +144,7 @@ class ModelArchConfigConvertorBase:
|
|||||||
|
|
||||||
return config_dtype
|
return config_dtype
|
||||||
|
|
||||||
|
@final
|
||||||
@classmethod
|
@classmethod
|
||||||
def _normalize_quantization_config(cls, config: PretrainedConfig):
|
def _normalize_quantization_config(cls, config: PretrainedConfig):
|
||||||
quant_cfg = getattr(config, "quantization_config", None)
|
quant_cfg = getattr(config, "quantization_config", None)
|
||||||
@ -171,6 +177,7 @@ class ModelArchConfigConvertorBase:
|
|||||||
|
|
||||||
return quant_cfg
|
return quant_cfg
|
||||||
|
|
||||||
|
@final
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_quantization_config(cls, hf_config: PretrainedConfig):
|
def get_quantization_config(cls, hf_config: PretrainedConfig):
|
||||||
quant_cfg = cls._normalize_quantization_config(hf_config)
|
quant_cfg = cls._normalize_quantization_config(hf_config)
|
||||||
@ -258,7 +265,7 @@ class ModelArchConfigConvertorBase:
|
|||||||
head_size=self.get_head_size(),
|
head_size=self.get_head_size(),
|
||||||
vocab_size=self.get_vocab_size(),
|
vocab_size=self.get_vocab_size(),
|
||||||
total_num_kv_heads=self.get_total_num_kv_heads(),
|
total_num_kv_heads=self.get_total_num_kv_heads(),
|
||||||
num_experts=self.get_num_experts(),
|
num_experts=self.get_num_experts(self.hf_text_config),
|
||||||
quantization_config=self.get_quantization_config(self.hf_config),
|
quantization_config=self.get_quantization_config(self.hf_config),
|
||||||
torch_dtype=self.get_torch_dtype(self.hf_config, model_id, revision),
|
torch_dtype=self.get_torch_dtype(self.hf_config, model_id, revision),
|
||||||
is_multimodal_model=self.is_multimodal_model(),
|
is_multimodal_model=self.is_multimodal_model(),
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user