Fix some more Transformers nightly tests (#29872)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
parent: 3ff5b53bc2
commit: 6fc5841db1
@@ -1801,7 +1801,10 @@ def run_tarsier2(questions: list[str], modality: str) -> ModelRequestData:
     engine_args = EngineArgs(
         model=model_name,
         max_model_len=4096,
-        hf_overrides={"architectures": ["Tarsier2ForConditionalGeneration"]},
+        hf_overrides={
+            "architectures": ["Tarsier2ForConditionalGeneration"],
+            "model_type": "tarsier2",
+        },
         limit_mm_per_prompt={modality: 1},
     )
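The examples now pass "model_type" alongside "architectures" in hf_overrides, so vLLM's config parser can select the registered Tarsier2 config class before AutoConfig builds the nested Llava structure. A minimal sketch of the same override from user code (model name taken from the test registry below; other values illustrative):

from vllm import LLM

# hf_overrides is merged into the Hugging Face config; its "model_type" is
# consulted against vLLM's _CONFIG_REGISTRY before AutoConfig runs.
llm = LLM(
    model="omni-research/Tarsier2-Recap-7b",
    max_model_len=4096,
    hf_overrides={
        "architectures": ["Tarsier2ForConditionalGeneration"],
        "model_type": "tarsier2",
    },
)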
@@ -1222,7 +1222,10 @@ def load_tarsier2(question: str, image_urls: list[str]) -> ModelRequestData:
         trust_remote_code=True,
         max_model_len=32768,
         limit_mm_per_prompt={"image": len(image_urls)},
-        hf_overrides={"architectures": ["Tarsier2ForConditionalGeneration"]},
+        hf_overrides={
+            "architectures": ["Tarsier2ForConditionalGeneration"],
+            "model_type": "tarsier2",
+        },
     )

     prompt = (
@@ -831,7 +831,10 @@ _MULTIMODAL_EXAMPLE_MODELS = {
     "TarsierForConditionalGeneration": _HfExamplesInfo("omni-research/Tarsier-7b"),
     "Tarsier2ForConditionalGeneration": _HfExamplesInfo(
         "omni-research/Tarsier2-Recap-7b",
-        hf_overrides={"architectures": ["Tarsier2ForConditionalGeneration"]},
+        hf_overrides={
+            "architectures": ["Tarsier2ForConditionalGeneration"],
+            "model_type": "tarsier2",
+        },
     ),
     "VoxtralForConditionalGeneration": _HfExamplesInfo(
         "mistralai/Voxtral-Mini-3B-2507",
@@ -1576,15 +1576,6 @@ class Tarsier2ForConditionalGeneration(Qwen2VLForConditionalGeneration):
         }
     )

-    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
-        # Tarsier2 uses llava as model_type, which will create a Qwen2VLConfig
-        # as text_config, we need to reconstruct Qwen2VLConfig from LlavaConfig.
-        config = vllm_config.model_config.hf_config
-        qwen2vl_config = config.text_config
-        qwen2vl_config.architectures = config.architectures
-        vllm_config.model_config.hf_config = qwen2vl_config
-        super().__init__(vllm_config=vllm_config, prefix=prefix)
-
     def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
         skip_prefixes = []
         if self.visual is None:
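With Tarsier2Config registered (see the new file at the end of this diff), vllm_config.model_config.hf_config is already a flat Qwen2VLConfig by the time the model class is constructed, so the __init__ override that rebuilt the config from a LlavaConfig wrapper can be deleted. A hedged sketch of the invariant the deletion relies on (the check itself is illustrative):

from transformers import Qwen2VLConfig

from vllm.transformers_utils.configs.tarsier2 import Tarsier2Config

# The registered config class subclasses Qwen2VLConfig directly, so the base
# Qwen2VLForConditionalGeneration.__init__ can consume it unchanged.
assert issubclass(Tarsier2Config, Qwen2VLConfig)
assert Tarsier2Config.model_type == "tarsier2"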
@@ -14,13 +14,19 @@ if TYPE_CHECKING:
     )
     from mistral_common.tokens.tokenizers.tekken import Tekkenizer
     from transformers import BatchEncoding
-    from transformers.tokenization_mistral_common import (
-        MistralCommonTokenizer as TransformersMistralTokenizer,
-    )

     from vllm.entrypoints.chat_utils import ChatCompletionMessageParam
     from vllm.entrypoints.openai.protocol import ChatCompletionRequest

+try:
+    # Transformers v5
+    from transformers.tokenization_mistral_common import MistralCommonBackend
+except ImportError:
+    # Transformers v4
+    from transformers.tokenization_mistral_common import (
+        MistralCommonTokenizer as MistralCommonBackend,
+    )
+
 logger = init_logger(__name__)

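Transformers v5 renames MistralCommonTokenizer to MistralCommonBackend; the try/except import binds one local alias that works on both major versions. The same pattern with a real standard-library example, for illustration:

# Version-compatibility import: prefer the new name, fall back to the old one
# under a single alias so the rest of the module stays version-agnostic.
try:
    import tomllib  # Python >= 3.11
except ImportError:
    import tomli as tomllib  # Python < 3.11, requires `pip install tomli`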
@@ -208,11 +214,17 @@ class MistralTokenizer(TokenizerLike):
         **kwargs,
     ) -> "MistralTokenizer":
         from mistral_common.protocol.instruct.validator import ValidationMode
-        from transformers.tokenization_mistral_common import (
-            MistralCommonTokenizer as TransformersMistralTokenizer,
-        )

-        tokenizer = TransformersMistralTokenizer.from_pretrained(
+        try:
+            # Transformers v5
+            from transformers.tokenization_mistral_common import MistralCommonBackend
+        except ImportError:
+            # Transformers v4
+            from transformers.tokenization_mistral_common import (
+                MistralCommonTokenizer as MistralCommonBackend,
+            )
+
+        tokenizer = MistralCommonBackend.from_pretrained(
             path_or_repo_id,
             *args,
             mode=ValidationMode.test,
@@ -223,7 +235,7 @@ class MistralTokenizer(TokenizerLike):

         return cls(tokenizer)

-    def __init__(self, tokenizer: "TransformersMistralTokenizer") -> None:
+    def __init__(self, tokenizer: "MistralCommonBackend") -> None:
         super().__init__()

         from mistral_common.protocol.instruct.validator import ValidationMode
@@ -89,6 +89,7 @@ _CONFIG_REGISTRY: dict[str, type[PretrainedConfig]] = LazyConfigDict(
     step3_text="Step3TextConfig",
     qwen3_next="Qwen3NextConfig",
     lfm2_moe="Lfm2MoeConfig",
+    tarsier2="Tarsier2Config",
 )

 _CONFIG_ATTRS_MAPPING: dict[str, str] = {
@@ -127,6 +128,9 @@ class HFConfigParser(ConfigParserBase):
             if config_dict.get("speculators_config") is not None
             else model_type
         )
+        # Allow hf_overrides to override model_type before checking _CONFIG_REGISTRY
+        if (hf_overrides := kwargs.pop("hf_overrides", None)) is not None:
+            model_type = hf_overrides.get("model_type", model_type)

         if model_type in _CONFIG_REGISTRY:
             config_class = _CONFIG_REGISTRY[model_type]
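This is the hook that makes the example changes above work: a user-supplied "model_type" in hf_overrides is honored before the registry lookup. The pattern in isolation (function name and call are illustrative):

def resolve_model_type(model_type: str, **kwargs) -> str:
    # Pop hf_overrides so it is not forwarded further; a user-supplied
    # model_type takes precedence over the one read from config.json.
    if (hf_overrides := kwargs.pop("hf_overrides", None)) is not None:
        model_type = hf_overrides.get("model_type", model_type)
    return model_type

# Tarsier2 ships a config.json with model_type "llava"; the override wins.
assert resolve_model_type("llava", hf_overrides={"model_type": "tarsier2"}) == "tarsier2"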
@@ -310,7 +314,7 @@ def patch_rope_parameters(config: PretrainedConfig) -> None:
         config.rope_parameters["rope_theta"] = rope_theta

     # No RoPE parameters to patch
-    if not hasattr(config, "rope_parameters"):
+    if getattr(config, "rope_parameters", None) is None:
         return

     # Add original_max_position_embeddings if present
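The old hasattr check would fall through when a config defines rope_parameters = None explicitly; getattr(..., None) is None treats "missing" and "present but None" the same way. A standalone illustration:

class Cfg:
    rope_parameters = None  # attribute exists, but carries no parameters

cfg = Cfg()
print(not hasattr(cfg, "rope_parameters"))             # False: old check would not return early
print(getattr(cfg, "rope_parameters", None) is None)   # True: new check returns early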
@@ -351,7 +355,10 @@ def patch_rope_parameters_dict(rope_parameters: dict[str, Any]) -> None:
         rope_parameters["rope_type"] = "longrope"
         logger.warning("Replacing legacy rope_type 'su' with 'longrope'")
     elif rope_parameters["rope_type"] == "mrope":
-        assert "mrope_section" in rope_parameters
+        if "mrope_section" not in rope_parameters:
+            raise ValueError(
+                "Legacy rope_type 'mrope' requires 'mrope_section' in rope_parameters"
+            )
         rope_parameters["rope_type"] = "default"
         logger.warning("Replacing legacy rope_type 'mrope' with 'default'")
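assert statements are stripped when Python runs with -O, so the old check could silently vanish in optimized deployments; an explicit ValueError always fires and carries an actionable message. A quick demonstration of the difference:

import subprocess
import sys

# Under -O, asserts are compiled away entirely.
r = subprocess.run([sys.executable, "-O", "-c", "assert False"])
print(r.returncode)  # 0: the assert never executed

# An explicit raise is unaffected by interpreter optimization flags.
r = subprocess.run([sys.executable, "-O", "-c", "raise ValueError('bad rope')"])
print(r.returncode)  # 1: the error still surfaces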
@@ -584,6 +591,7 @@ def get_config(
             trust_remote_code=trust_remote_code,
             revision=revision,
             code_revision=code_revision,
+            hf_overrides=hf_overrides_kw,
             **kwargs,
         )
     # Special architecture mapping check for GGUF models
@@ -915,11 +923,13 @@ def get_hf_text_config(config: PretrainedConfig):
     """
     text_config = config.get_text_config()

-    if text_config is not config:
-        # The code operates under the assumption that text_config should have
-        # `num_attention_heads` (among others). Assert here to fail early
-        # if transformers config doesn't align with this assumption.
-        assert hasattr(text_config, "num_attention_heads")
+    if text_config is not config and not hasattr(text_config, "num_attention_heads"):
+        raise ValueError(
+            "The text_config extracted from the model config does not have "
+            "`num_attention_heads` attribute. This indicates a mismatch "
+            "between the model config and vLLM's expectations. Please "
+            "ensure that the model config is compatible with vLLM."
+        )

     return text_config
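get_text_config() returns a nested sub-config for composite multimodal configs and the config itself for flat ones; the new guard verifies the sub-config actually carries attention metadata, with a real error instead of a bare assert. A hedged offline illustration with a composite config:

from transformers import LlavaConfig

cfg = LlavaConfig()  # composite config: text_config + vision_config
text = cfg.get_text_config()
assert text is not cfg                        # a nested sub-config was returned
assert hasattr(text, "num_attention_heads")   # the property the guard requires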
@@ -48,6 +48,7 @@ from vllm.transformers_utils.configs.step3_vl import (
     Step3VisionEncoderConfig,
     Step3VLConfig,
 )
+from vllm.transformers_utils.configs.tarsier2 import Tarsier2Config
 from vllm.transformers_utils.configs.ultravox import UltravoxConfig

 __all__ = [
@@ -81,4 +82,5 @@ __all__ = [
     "Step3VisionEncoderConfig",
     "Step3TextConfig",
     "Qwen3NextConfig",
+    "Tarsier2Config",
 ]
vllm/transformers_utils/configs/tarsier2.py (new file, 24 lines)
@@ -0,0 +1,24 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+from transformers import Qwen2VLConfig
+
+
+class Tarsier2Config(Qwen2VLConfig):
+    """
+    Tarsier2's config.json is written such that AutoConfig.from_pretrained will create
+    a deeply nested config consisting of:
+
+    - LlavaConfig
+      - Qwen2VLConfig
+        - Qwen2VLTextConfig
+        - Qwen2VLVisionConfig
+      - Qwen2VLConfig
+        - Qwen2VLTextConfig
+        - Qwen2VLVisionConfig
+
+    When it should really just be a single Qwen2VLConfig.
+
+    This class is a hack to stop AutoConfig from creating the nested config structure.
+    """
+
+    model_type = "tarsier2"
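Because Tarsier2Config subclasses Qwen2VLConfig and only pins model_type, instantiating it yields the flat config directly. A small sketch of what the registration buys (construction with default values, purely for illustration):

from vllm.transformers_utils.configs import Tarsier2Config

cfg = Tarsier2Config()          # flat Qwen2VLConfig, no LlavaConfig wrapper
print(cfg.model_type)           # "tarsier2"
print(type(cfg).__bases__[0])   # the Qwen2VLConfig base class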