[Bugfix] Fix unable to load some models (#10312)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Cyrus Leung 2024-11-15 08:55:54 +08:00 committed by GitHub
parent 11cd1ae6ad
commit 972112d82f
13 changed files with 340 additions and 59 deletions


@@ -313,14 +313,15 @@ steps:
##### models test #####
- label: Basic Models Test # 10min
- label: Basic Models Test # 30min
source_file_dependencies:
- vllm/
- tests/models
commands:
- pip install -e ./plugins/vllm_add_dummy_model
- pytest -v -s models/test_oot_registration.py # it needs a clean process
- pytest -v -s models/*.py --ignore=models/test_oot_registration.py
- pytest -v -s models/test_registry.py
- pytest -v -s models/test_initialization.py
- label: Decoder-only Language Models Test (Standard) # 18min
#mirror_hardwares: [amd]


@@ -166,14 +166,14 @@ TEXT_GENERATION_MODELS = {
"mistralai/Mixtral-8x7B-Instruct-v0.1": PPTestSettings.fast(tp_base=4),
"mosaicml/mpt-7b": PPTestSettings.fast(),
"nvidia/Minitron-8B-Base": PPTestSettings.fast(),
"allenai/OLMoE-1B-7B-0924-Instruct": PPTestSettings.fast(),
"allenai/OLMo-1B-hf": PPTestSettings.fast(),
"allenai/OLMoE-1B-7B-0924-Instruct": PPTestSettings.fast(),
"facebook/opt-iml-max-1.3b": PPTestSettings.fast(),
"OrionStarAI/Orion-14B-Chat": PPTestSettings.fast(trust_remote_code=True),
"microsoft/phi-2": PPTestSettings.fast(),
"microsoft/Phi-3.5-MoE-instruct": PPTestSettings.detailed(trust_remote_code=True, multi_node_only=True, load_format="dummy", hf_overrides='{"num_hidden_layers": 4, "hidden_size": 512, "intermediate_size": 800, "num_attention_heads": 4, "num_key_value_heads": 1}'), # noqa: E501
"microsoft/Phi-3-small-8k-instruct": PPTestSettings.fast(trust_remote_code=True), # noqa: E501
"adept/persimmon-8b-chat": PPTestSettings.fast(),
"microsoft/phi-2": PPTestSettings.fast(),
"microsoft/Phi-3-small-8k-instruct": PPTestSettings.fast(trust_remote_code=True), # noqa: E501
"microsoft/Phi-3.5-MoE-instruct": PPTestSettings.detailed(trust_remote_code=True, multi_node_only=True, load_format="dummy", hf_overrides='{"num_hidden_layers": 4, "hidden_size": 512, "intermediate_size": 800, "num_attention_heads": 4, "num_key_value_heads": 1}'), # noqa: E501
"Qwen/Qwen-7B-Chat": PPTestSettings.fast(trust_remote_code=True),
"Qwen/Qwen2-7B-Instruct": PPTestSettings.fast(),
"Qwen/Qwen1.5-MoE-A2.7B-Chat": PPTestSettings.fast(),

tests/models/registry.py (new file, 212 lines)

@@ -0,0 +1,212 @@
from dataclasses import dataclass, field
from typing import AbstractSet, Mapping, Optional
@dataclass(frozen=True)
class _HfExamplesInfo:
default: str
"""The default model to use for testing this architecture."""
extras: Mapping[str, str] = field(default_factory=dict)
"""Extra models to use for testing this architecture."""
tokenizer: Optional[str] = None
"""Set the tokenizer to load for this architecture."""
tokenizer_mode: str = "auto"
"""Set the tokenizer type for this architecture."""
speculative_model: Optional[str] = None
"""
The default model to use for testing this architecture, which is only used
for speculative decoding.
"""
is_available_online: bool = True
"""
Set this to ``False`` if the name of this architecture no longer exists on
the HF repo. To maintain backwards compatibility, we have not removed them
from the main model registry, so without this flag the registry tests will
fail.
"""
trust_remote_code: bool = False
"""The ``trust_remote_code`` level required to load the model."""
# yapf: disable
_TEXT_GENERATION_EXAMPLE_MODELS = {
# [Decoder-only]
"AquilaModel": _HfExamplesInfo("BAAI/AquilaChat-7B",
trust_remote_code=True),
"AquilaForCausalLM": _HfExamplesInfo("BAAI/AquilaChat2-7B",
trust_remote_code=True),
"ArcticForCausalLM": _HfExamplesInfo("Snowflake/snowflake-arctic-instruct",
trust_remote_code=True),
"BaiChuanForCausalLM": _HfExamplesInfo("baichuan-inc/Baichuan-7B",
trust_remote_code=True),
"BaichuanForCausalLM": _HfExamplesInfo("baichuan-inc/Baichuan2-7B-chat",
trust_remote_code=True),
"BloomForCausalLM": _HfExamplesInfo("bigscience/bloomz-1b1"),
# ChatGLMModel supports multimodal
"CohereForCausalLM": _HfExamplesInfo("CohereForAI/c4ai-command-r-v01",
trust_remote_code=True),
"DbrxForCausalLM": _HfExamplesInfo("databricks/dbrx-instruct"),
"DeciLMForCausalLM": _HfExamplesInfo("Deci/DeciLM-7B-instruct",
trust_remote_code=True),
"DeepseekForCausalLM": _HfExamplesInfo("deepseek-ai/deepseek-llm-7b-chat"),
"DeepseekV2ForCausalLM": _HfExamplesInfo("deepseek-ai/DeepSeek-V2-Lite-Chat", # noqa: E501
trust_remote_code=True),
"ExaoneForCausalLM": _HfExamplesInfo("LGAI-EXAONE/EXAONE-3.0-7.8B-Instruct"), # noqa: E501
"FalconForCausalLM": _HfExamplesInfo("tiiuae/falcon-7b"),
"GemmaForCausalLM": _HfExamplesInfo("google/gemma-2b"),
"Gemma2ForCausalLM": _HfExamplesInfo("google/gemma-2-9b"),
"GPT2LMHeadModel": _HfExamplesInfo("gpt2"),
"GPTBigCodeForCausalLM": _HfExamplesInfo("bigcode/starcoder"),
"GPTJForCausalLM": _HfExamplesInfo("EleutherAI/gpt-j-6b"),
"GPTNeoXForCausalLM": _HfExamplesInfo("EleutherAI/pythia-160m"),
"GraniteForCausalLM": _HfExamplesInfo("ibm/PowerLM-3b"),
"GraniteMoeForCausalLM": _HfExamplesInfo("ibm/PowerMoE-3b"),
"InternLMForCausalLM": _HfExamplesInfo("internlm/internlm-chat-7b",
trust_remote_code=True),
"InternLM2ForCausalLM": _HfExamplesInfo("internlm/internlm2-chat-7b",
trust_remote_code=True),
"InternLM2VEForCausalLM": _HfExamplesInfo("OpenGVLab/Mono-InternVL-2B",
trust_remote_code=True),
"JAISLMHeadModel": _HfExamplesInfo("inceptionai/jais-13b-chat"),
"JambaForCausalLM": _HfExamplesInfo("ai21labs/AI21-Jamba-1.5-Mini"),
"LlamaForCausalLM": _HfExamplesInfo("meta-llama/Meta-Llama-3-8B"),
"LLaMAForCausalLM": _HfExamplesInfo("decapoda-research/llama-7b-hf",
is_available_online=False),
"MambaForCausalLM": _HfExamplesInfo("state-spaces/mamba-130m-hf"),
"FalconMambaForCausalLM": _HfExamplesInfo("tiiuae/falcon-mamba-7b-instruct"), # noqa: E501
"MiniCPMForCausalLM": _HfExamplesInfo("openbmb/MiniCPM-2B-sft-bf16",
trust_remote_code=True),
"MiniCPM3ForCausalLM": _HfExamplesInfo("openbmb/MiniCPM3-4B",
trust_remote_code=True),
"MistralForCausalLM": _HfExamplesInfo("mistralai/Mistral-7B-Instruct-v0.1"),
"MixtralForCausalLM": _HfExamplesInfo("mistralai/Mixtral-8x7B-Instruct-v0.1"), # noqa: E501
"QuantMixtralForCausalLM": _HfExamplesInfo("mistral-community/Mixtral-8x22B-v0.1-AWQ"), # noqa: E501
"MptForCausalLM": _HfExamplesInfo("mpt", is_available_online=False),
"MPTForCausalLM": _HfExamplesInfo("mosaicml/mpt-7b"),
"NemotronForCausalLM": _HfExamplesInfo("nvidia/Minitron-8B-Base"),
"OlmoForCausalLM": _HfExamplesInfo("allenai/OLMo-1B-hf"),
"OlmoeForCausalLM": _HfExamplesInfo("allenai/OLMoE-1B-7B-0924-Instruct"),
"OPTForCausalLM": _HfExamplesInfo("facebook/opt-iml-max-1.3b"),
"OrionForCausalLM": _HfExamplesInfo("OrionStarAI/Orion-14B-Chat",
trust_remote_code=True),
"PersimmonForCausalLM": _HfExamplesInfo("adept/persimmon-8b-chat"),
"PhiForCausalLM": _HfExamplesInfo("microsoft/phi-2"),
"Phi3ForCausalLM": _HfExamplesInfo("microsoft/Phi-3-mini-4k-instruct"),
"Phi3SmallForCausalLM": _HfExamplesInfo("microsoft/Phi-3-small-8k-instruct",
trust_remote_code=True),
"PhiMoEForCausalLM": _HfExamplesInfo("microsoft/Phi-3.5-MoE-instruct",
trust_remote_code=True),
# QWenLMHeadModel supports multimodal
"Qwen2ForCausalLM": _HfExamplesInfo("Qwen/Qwen2-7B-Instruct"),
"Qwen2MoeForCausalLM": _HfExamplesInfo("Qwen/Qwen1.5-MoE-A2.7B-Chat"),
"RWForCausalLM": _HfExamplesInfo("tiiuae/falcon-40b",
is_available_online=False),
"StableLMEpochForCausalLM": _HfExamplesInfo("stabilityai/stablelm-zephyr-3b", # noqa: E501
is_available_online=False),
"StableLmForCausalLM": _HfExamplesInfo("stabilityai/stablelm-3b-4e1t"),
"Starcoder2ForCausalLM": _HfExamplesInfo("bigcode/starcoder2-3b"),
"SolarForCausalLM": _HfExamplesInfo("upstage/solar-pro-preview-instruct"),
"XverseForCausalLM": _HfExamplesInfo("xverse/XVERSE-7B-Chat",
is_available_online=False,
trust_remote_code=True),
# [Encoder-decoder]
"BartModel": _HfExamplesInfo("facebook/bart-base"),
"BartForConditionalGeneration": _HfExamplesInfo("facebook/bart-large-cnn"),
# Florence-2 uses BartFastTokenizer which can't be loaded from AutoTokenizer
# Therefore, we borrow the BartTokenizer from the original Bart model
"Florence2ForConditionalGeneration": _HfExamplesInfo("microsoft/Florence-2-base", # noqa: E501
tokenizer="facebook/bart-base",
trust_remote_code=True), # noqa: E501
}
_EMBEDDING_EXAMPLE_MODELS = {
# [Text-only]
"BertModel": _HfExamplesInfo("BAAI/bge-base-en-v1.5"),
"Gemma2Model": _HfExamplesInfo("BAAI/bge-multilingual-gemma2"),
"MistralModel": _HfExamplesInfo("intfloat/e5-mistral-7b-instruct"),
"Qwen2ForRewardModel": _HfExamplesInfo("Qwen/Qwen2.5-Math-RM-72B"),
"Qwen2ForSequenceClassification": _HfExamplesInfo("jason9693/Qwen2.5-1.5B-apeach"), # noqa: E501
# [Multimodal]
"LlavaNextForConditionalGeneration": _HfExamplesInfo("royokong/e5-v"),
"Phi3VForCausalLM": _HfExamplesInfo("TIGER-Lab/VLM2Vec-Full",
trust_remote_code=True),
"Qwen2VLForConditionalGeneration": _HfExamplesInfo("MrLight/dse-qwen2-2b-mrl-v1"), # noqa: E501
}
_MULTIMODAL_EXAMPLE_MODELS = {
# [Decoder-only]
"Blip2ForConditionalGeneration": _HfExamplesInfo("Salesforce/blip2-opt-2.7b"), # noqa: E501
"ChameleonForConditionalGeneration": _HfExamplesInfo("facebook/chameleon-7b"), # noqa: E501
"ChatGLMModel": _HfExamplesInfo("THUDM/glm-4v-9b",
extras={"text_only": "THUDM/chatglm3-6b"},
trust_remote_code=True),
"ChatGLMForConditionalGeneration": _HfExamplesInfo("chatglm2-6b",
is_available_online=False),
"FuyuForCausalLM": _HfExamplesInfo("adept/fuyu-8b"),
"H2OVLChatModel": _HfExamplesInfo("h2oai/h2ovl-mississippi-800m"),
"InternVLChatModel": _HfExamplesInfo("OpenGVLab/InternVL2-1B",
trust_remote_code=True),
"Idefics3ForConditionalGeneration": _HfExamplesInfo("HuggingFaceM4/Idefics3-8B-Llama3"), # noqa: E501
"LlavaForConditionalGeneration": _HfExamplesInfo("llava-hf/llava-1.5-7b-hf",
extras={"mistral": "mistral-community/pixtral-12b"}), # noqa: E501
"LlavaNextForConditionalGeneration": _HfExamplesInfo("llava-hf/llava-v1.6-mistral-7b-hf"), # noqa: E501
"LlavaNextVideoForConditionalGeneration": _HfExamplesInfo("llava-hf/LLaVA-NeXT-Video-7B-hf"), # noqa: E501
"LlavaOnevisionForConditionalGeneration": _HfExamplesInfo("llava-hf/llava-onevision-qwen2-0.5b-ov-hf"), # noqa: E501
"MiniCPMV": _HfExamplesInfo("openbmb/MiniCPM-Llama3-V-2_5",
trust_remote_code=True),
"MolmoForCausalLM": _HfExamplesInfo("allenai/Molmo-7B-D-0924",
trust_remote_code=True),
"NVLM_D": _HfExamplesInfo("nvidia/NVLM-D-72B",
trust_remote_code=True),
"PaliGemmaForConditionalGeneration": _HfExamplesInfo("google/paligemma-3b-pt-224"), # noqa: E501
"Phi3VForCausalLM": _HfExamplesInfo("microsoft/Phi-3-vision-128k-instruct",
trust_remote_code=True),
"PixtralForConditionalGeneration": _HfExamplesInfo("mistralai/Pixtral-12B-2409", # noqa: E501
tokenizer_mode="mistral"),
"QWenLMHeadModel": _HfExamplesInfo("Qwen/Qwen-VL-Chat",
extras={"text_only": "Qwen/Qwen-7B-Chat"}, # noqa: E501
trust_remote_code=True),
"Qwen2AudioForConditionalGeneration": _HfExamplesInfo("Qwen/Qwen2-Audio-7B-Instruct"), # noqa: E501
"Qwen2VLForConditionalGeneration": _HfExamplesInfo("Qwen/Qwen2-VL-2B-Instruct"), # noqa: E501
"UltravoxModel": _HfExamplesInfo("fixie-ai/ultravox-v0_3"),
# [Encoder-decoder]
"MllamaForConditionalGeneration": _HfExamplesInfo("meta-llama/Llama-3.2-11B-Vision-Instruct"), # noqa: E501
}
_SPECULATIVE_DECODING_EXAMPLE_MODELS = {
"EAGLEModel": _HfExamplesInfo("JackFram/llama-68m",
speculative_model="abhigoyal/vllm-eagle-llama-68m-random"), # noqa: E501
"MedusaModel": _HfExamplesInfo("JackFram/llama-68m",
speculative_model="abhigoyal/vllm-medusa-llama-68m-random"), # noqa: E501
"MLPSpeculatorPreTrainedModel": _HfExamplesInfo("JackFram/llama-160m",
speculative_model="ibm-fms/llama-160m-accelerator"), # noqa: E501
}
_EXAMPLE_MODELS = {
**_TEXT_GENERATION_EXAMPLE_MODELS,
**_EMBEDDING_EXAMPLE_MODELS,
**_MULTIMODAL_EXAMPLE_MODELS,
**_SPECULATIVE_DECODING_EXAMPLE_MODELS,
}
class HfExampleModels:
def __init__(self, hf_models: Mapping[str, _HfExamplesInfo]) -> None:
super().__init__()
self.hf_models = hf_models
def get_supported_archs(self) -> AbstractSet[str]:
return self.hf_models.keys()
def get_hf_info(self, model_arch: str) -> _HfExamplesInfo:
return self.hf_models[model_arch]
HF_EXAMPLE_MODELS = HfExampleModels(_EXAMPLE_MODELS)
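
For illustration, a minimal sketch of how this registry can be queried (hypothetical snippet, not part of the diff; the absolute import assumes the tests directory is importable as a package, and the expected values are taken from the table above):

from tests.models.registry import HF_EXAMPLE_MODELS

# All architecture names that have an example HF checkpoint registered.
archs = HF_EXAMPLE_MODELS.get_supported_archs()
assert "LlamaForCausalLM" in archs

# Look up the default checkpoint and its loading flags for one architecture.
info = HF_EXAMPLE_MODELS.get_hf_info("LlamaForCausalLM")
assert info.default == "meta-llama/Meta-Llama-3-8B"
assert info.trust_remote_code is False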


@@ -0,0 +1,55 @@
from unittest.mock import patch
import pytest
import transformers
from transformers import PretrainedConfig
from vllm import LLM
from .registry import HF_EXAMPLE_MODELS
@pytest.mark.parametrize("model_arch", HF_EXAMPLE_MODELS.get_supported_archs())
def test_can_initialize(model_arch):
if (model_arch == "Idefics3ForConditionalGeneration"
and transformers.__version__ < "4.46.0"):
pytest.skip(reason="Model introduced in HF >= 4.46.0")
model_info = HF_EXAMPLE_MODELS.get_hf_info(model_arch)
if not model_info.is_available_online:
pytest.skip("Model is not available online")
# Avoid OOM
def hf_overrides(hf_config: PretrainedConfig) -> PretrainedConfig:
if hasattr(hf_config, "text_config"):
text_config: PretrainedConfig = hf_config.text_config
else:
text_config = hf_config
text_config.update({
"num_layers": 1,
"num_hidden_layers": 1,
"num_experts": 2,
"num_experts_per_tok": 2,
"num_local_experts": 2,
})
return hf_config
# Avoid calling model.forward()
def _initialize_kv_caches(self) -> None:
self.cache_config.num_gpu_blocks = 0
self.cache_config.num_cpu_blocks = 0
with patch.object(LLM.get_engine_class(), "_initialize_kv_caches",
_initialize_kv_caches):
LLM(
model_info.default,
tokenizer=model_info.tokenizer,
tokenizer_mode=model_info.tokenizer_mode,
speculative_model=model_info.speculative_model,
num_speculative_tokens=1 if model_info.speculative_model else None,
trust_remote_code=model_info.trust_remote_code,
load_format="dummy",
hf_overrides=hf_overrides,
)


@@ -14,6 +14,7 @@ from vllm.model_executor.models.registry import (_EMBEDDING_MODELS,
from vllm.platforms import current_platform
from ..utils import fork_new_process_for_each_test
from .registry import HF_EXAMPLE_MODELS
@pytest.mark.parametrize("model_arch", ModelRegistry.get_supported_archs())
@@ -73,3 +74,12 @@ def test_registry_is_pp(model_arch, is_pp, init_cuda):
"This model no longer initializes CUDA on import. "
"Please test using a different one.",
stacklevel=2)
def test_hf_registry_coverage():
untested_archs = (HF_EXAMPLE_MODELS.get_supported_archs() -
set(ModelRegistry.get_supported_archs()))
assert not untested_archs, (
"Please add the following architectures to "
f"`tests/models/registry.py`: {untested_archs}")


@@ -3,8 +3,8 @@ import enum
import json
import warnings
from dataclasses import dataclass, field, replace
from typing import (TYPE_CHECKING, Any, ClassVar, Dict, Final, List, Literal,
Mapping, Optional, Set, Tuple, Type, Union)
from typing import (TYPE_CHECKING, Any, Callable, ClassVar, Dict, Final, List,
Literal, Mapping, Optional, Set, Tuple, Type, Union)
import torch
from transformers import PretrainedConfig
@@ -20,7 +20,7 @@ from vllm.transformers_utils.config import (
get_hf_text_config, get_pooling_config,
get_sentence_transformer_tokenizer_config, is_encoder_decoder, uses_mrope)
from vllm.utils import (GiB_bytes, cuda_device_count_stateless, get_cpu_memory,
print_warning_once)
identity, print_warning_once)
if TYPE_CHECKING:
from ray.util.placement_group import PlacementGroup
@@ -44,6 +44,9 @@ TaskOption = Literal["auto", "generate", "embedding"]
# "draft" is only used internally for speculative decoding
_Task = Literal["generate", "embedding", "draft"]
HfOverrides = Union[Dict[str, Any], Callable[[PretrainedConfig],
PretrainedConfig]]
class ModelConfig:
"""Configuration for the model.
@@ -115,7 +118,9 @@ class ModelConfig:
can not be gathered from the vllm arguments.
config_format: The config format which shall be loaded.
Defaults to 'auto' which defaults to 'hf'.
hf_overrides: Arguments to be forwarded to the HuggingFace config.
hf_overrides: If a dictionary, contains arguments to be forwarded to the
HuggingFace config. If a callable, it is called to update the
HuggingFace config.
mm_processor_kwargs: Arguments to be forwarded to the model's processor
for multi-modal data, e.g., image processor.
pooling_type: Used to configure the pooling method in the embedding
@@ -164,7 +169,7 @@ class ModelConfig:
override_neuron_config: Optional[Dict[str, Any]] = None,
config_format: ConfigFormat = ConfigFormat.AUTO,
chat_template_text_format: str = "string",
hf_overrides: Optional[Dict[str, Any]] = None,
hf_overrides: Optional[HfOverrides] = None,
mm_processor_kwargs: Optional[Dict[str, Any]] = None,
pooling_type: Optional[str] = None,
pooling_norm: Optional[bool] = None,
@@ -182,15 +187,23 @@ class ModelConfig:
if hf_overrides is None:
hf_overrides = {}
if callable(hf_overrides):
hf_overrides_kw = {}
hf_overrides_fn = hf_overrides
else:
hf_overrides_kw = hf_overrides
hf_overrides_fn = identity
if rope_scaling is not None:
hf_override: Dict[str, Any] = {"rope_scaling": rope_scaling}
hf_overrides.update(hf_override)
hf_overrides_kw.update(hf_override)
msg = ("`--rope-scaling` will be removed in a future release. "
f"'Please instead use `--hf-overrides '{hf_override!r}'`")
warnings.warn(DeprecationWarning(msg), stacklevel=2)
if rope_theta is not None:
hf_override = {"rope_theta": rope_theta}
hf_overrides.update(hf_override)
hf_overrides_kw.update(hf_override)
msg = ("`--rope-theta` will be removed in a future release. "
f"'Please instead use `--hf-overrides '{hf_override!r}'`")
warnings.warn(DeprecationWarning(msg), stacklevel=2)
@@ -207,9 +220,12 @@ class ModelConfig:
self.max_logprobs = max_logprobs
self.disable_sliding_window = disable_sliding_window
self.skip_tokenizer_init = skip_tokenizer_init
self.hf_config = get_config(self.model, trust_remote_code, revision,
code_revision, config_format,
**hf_overrides)
hf_config = get_config(self.model, trust_remote_code, revision,
code_revision, config_format, **hf_overrides_kw)
hf_config = hf_overrides_fn(hf_config)
self.hf_config = hf_config
self.hf_text_config = get_hf_text_config(self.hf_config)
self.encoder_config = self._get_encoder_config()
self.hf_image_processor_config = get_hf_image_processor_config(
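
For illustration, a condensed sketch of the dict-vs-callable handling above in isolation (hypothetical helper load_hf_config; AutoConfig stands in here for vLLM's get_config, on the assumption that both forward keyword overrides onto the loaded HF config):

from typing import Any, Callable, Dict, Union

from transformers import AutoConfig, PretrainedConfig

HfOverrides = Union[Dict[str, Any],
                    Callable[[PretrainedConfig], PretrainedConfig]]

def load_hf_config(model: str, hf_overrides: HfOverrides) -> PretrainedConfig:
    if callable(hf_overrides):
        # Callable form: nothing extra for the loader; apply the function afterwards.
        hf_overrides_kw: Dict[str, Any] = {}
        hf_overrides_fn = hf_overrides
    else:
        # Dict form: forwarded as keyword overrides; post-processing is the identity.
        hf_overrides_kw = hf_overrides
        hf_overrides_fn = lambda config: config

    config = AutoConfig.from_pretrained(model, **hf_overrides_kw)
    return hf_overrides_fn(config)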


@@ -9,9 +9,9 @@ import torch
import vllm.envs as envs
from vllm.config import (CacheConfig, ConfigFormat, DecodingConfig,
DeviceConfig, LoadConfig, LoadFormat, LoRAConfig,
ModelConfig, ObservabilityConfig, ParallelConfig,
PromptAdapterConfig, SchedulerConfig,
DeviceConfig, HfOverrides, LoadConfig, LoadFormat,
LoRAConfig, ModelConfig, ObservabilityConfig,
ParallelConfig, PromptAdapterConfig, SchedulerConfig,
SpeculativeConfig, TaskOption, TokenizerPoolConfig,
VllmConfig)
from vllm.executor.executor_base import ExecutorBase
@@ -128,7 +128,7 @@ class EngineArgs:
code_revision: Optional[str] = None
rope_scaling: Optional[Dict[str, Any]] = None
rope_theta: Optional[float] = None
hf_overrides: Optional[Dict[str, Any]] = None
hf_overrides: Optional[HfOverrides] = None
tokenizer_revision: Optional[str] = None
quantization: Optional[str] = None
enforce_eager: Optional[bool] = None


@@ -9,7 +9,7 @@ from tqdm import tqdm
from vllm import envs
from vllm.beam_search import (BeamSearchInstance, BeamSearchOutput,
BeamSearchSequence, get_beam_search_score)
from vllm.engine.arg_utils import EngineArgs, TaskOption
from vllm.engine.arg_utils import EngineArgs, HfOverrides, TaskOption
from vllm.engine.llm_engine import LLMEngine
from vllm.entrypoints.chat_utils import (ChatCompletionMessageParam,
apply_hf_chat_template,
@@ -101,7 +101,9 @@ class LLM:
disable_custom_all_reduce: See :class:`~vllm.config.ParallelConfig`
disable_async_output_proc: Disable async output processing.
This may result in lower performance.
hf_overrides: Arguments to be forwarded to the HuggingFace config.
hf_overrides: If a dictionary, contains arguments to be forwarded to the
HuggingFace config. If a callable, it is called to update the
HuggingFace config.
**kwargs: Arguments for :class:`~vllm.EngineArgs`. (See
:ref:`engine_args`)
@@ -156,7 +158,7 @@ class LLM:
max_seq_len_to_capture: int = 8192,
disable_custom_all_reduce: bool = False,
disable_async_output_proc: bool = False,
hf_overrides: Optional[dict] = None,
hf_overrides: Optional[HfOverrides] = None,
mm_processor_kwargs: Optional[Dict[str, Any]] = None,
# After positional args are removed, move this right below `model`
task: TaskOption = "auto",
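
For illustration, a hypothetical usage sketch of the two accepted forms (the model name and override value are placeholders, not part of this change; load_format="dummy" mirrors the initialization test added above and avoids loading real weights):

from transformers import PretrainedConfig

from vllm import LLM

# Dict form: entries are forwarded to the HuggingFace config loader.
llm = LLM(model="facebook/opt-125m",
          hf_overrides={"num_hidden_layers": 2},
          load_format="dummy")

# Callable form: receives the loaded config and returns the updated one.
def shrink(config: PretrainedConfig) -> PretrainedConfig:
    config.update({"num_hidden_layers": 2})
    return config

llm = LLM(model="facebook/opt-125m",
          hf_overrides=shrink,
          load_format="dummy")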


@@ -41,7 +41,8 @@ from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, IntermediateTensors,
from vllm.utils import is_list_of
from .interfaces import SupportsMultiModal, SupportsPP
from .utils import AutoWeightsLoader, flatten_bn, merge_multimodal_embeddings
from .utils import (AutoWeightsLoader, flatten_bn, maybe_prefix,
merge_multimodal_embeddings)
# Cannot find the following 2 numbers from hf config.
_IMAGE_TOKEN_ID = 71011
@@ -245,7 +246,9 @@ class FuyuForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
gather_output=True,
)
self.language_model = PersimmonForCausalLM(
vllm_config.with_hf_config(config.text_config))
vllm_config=vllm_config.with_hf_config(config.text_config),
prefix=maybe_prefix(prefix, "language_model"),
)
self.make_empty_intermediate_tensors = (
self.language_model.make_empty_intermediate_tensors)


@@ -161,11 +161,5 @@ class InternLM2VEForCausalLM(InternLM2ForCausalLM):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__(vllm_config=vllm_config, prefix=prefix)
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
self.model = InternLM2VEModel(config,
cache_config,
quant_config,
self.model = InternLM2VEModel(vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "model"))


@@ -382,11 +382,7 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP):
instantiated.
"""
def __init__(
self,
vllm_config: VllmConfig,
prefix: str = "",
):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
config = vllm_config.model_config.hf_config
multimodal_config = vllm_config.model_config.multimodal_config
quant_config = vllm_config.quant_config
@@ -699,12 +695,8 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP):
class MiniCPMV2_0(MiniCPMVBaseModel):
def __init__(
self,
vllm_config: VllmConfig,
prefix: str = "",
):
super().__init__(vllm_config)
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__(vllm_config=vllm_config, prefix=prefix)
assert self.version == (2, 0)
def init_llm(
@@ -857,12 +849,8 @@ class MiniCPMV2_5(MiniCPMVBaseModel, SupportsLoRA):
embedding_modules = {}
embedding_padding_modules = []
def __init__(
self,
vllm_config: VllmConfig,
prefix: str = "",
):
super().__init__(vllm_config)
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__(vllm_config=vllm_config, prefix=prefix)
assert self.version == (2, 5)
def init_llm(
@@ -999,12 +987,8 @@ class MiniCPMV2_6(MiniCPMVBaseModel, SupportsLoRA):
embedding_modules = {}
embedding_padding_modules = []
def __init__(
self,
vllm_config: VllmConfig,
prefix: str = "",
):
super().__init__(vllm_config)
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__(vllm_config=vllm_config, prefix=prefix)
assert self.version == (2, 6)
def init_llm(
@@ -1117,7 +1101,7 @@ class MiniCPMV(MiniCPMVBaseModel, SupportsLoRA):
embedding_modules = {}
embedding_padding_modules = []
def __new__(cls, vllm_config: VllmConfig, prefix: str = ""):
def __new__(cls, *, vllm_config: VllmConfig, prefix: str = ""):
config = vllm_config.model_config.hf_config
if not hasattr(config, "version"):
if config.hidden_size == 2304 and config.query_num == 64:


@@ -65,7 +65,7 @@ class MLPSpeculator(nn.Module):
https://huggingface.co/ibm-fms and https://huggingface.co/ibm-granite
"""
def __init__(self, vllm_config: VllmConfig, prefix: str = "") -> None:
def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
super().__init__()
config = vllm_config.model_config.hf_config
self.n_predict = config.n_predict


@@ -1,3 +1,7 @@
"""
Whenever you add an architecture to this page, please also update
`tests/models/registry.py` with example HuggingFace models for it.
"""
import importlib
import os
import pickle
@@ -58,14 +62,14 @@ _TEXT_GENERATION_MODELS = {
"LLaMAForCausalLM": ("llama", "LlamaForCausalLM"),
"MambaForCausalLM": ("mamba", "MambaForCausalLM"),
"FalconMambaForCausalLM": ("mamba", "MambaForCausalLM"),
"MiniCPMForCausalLM": ("minicpm", "MiniCPMForCausalLM"),
"MiniCPM3ForCausalLM": ("minicpm3", "MiniCPM3ForCausalLM"),
"MistralForCausalLM": ("llama", "LlamaForCausalLM"),
"MixtralForCausalLM": ("mixtral", "MixtralForCausalLM"),
"QuantMixtralForCausalLM": ("mixtral_quant", "MixtralForCausalLM"),
# transformers's mpt class has lower case
"MptForCausalLM": ("mpt", "MPTForCausalLM"),
"MPTForCausalLM": ("mpt", "MPTForCausalLM"),
"MiniCPMForCausalLM": ("minicpm", "MiniCPMForCausalLM"),
"MiniCPM3ForCausalLM": ("minicpm3", "MiniCPM3ForCausalLM"),
"NemotronForCausalLM": ("nemotron", "NemotronForCausalLM"),
"OlmoForCausalLM": ("olmo", "OlmoForCausalLM"),
"OlmoeForCausalLM": ("olmoe", "OlmoeForCausalLM"),