[Bugfix] Fix unable to load some models (#10312)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
parent 11cd1ae6ad
commit 972112d82f
@@ -313,14 +313,15 @@ steps:

##### models test #####

- label: Basic Models Test # 10min
- label: Basic Models Test # 30min
  source_file_dependencies:
  - vllm/
  - tests/models
  commands:
    - pip install -e ./plugins/vllm_add_dummy_model
    - pytest -v -s models/test_oot_registration.py # it needs a clean process
    - pytest -v -s models/*.py --ignore=models/test_oot_registration.py
    - pytest -v -s models/test_registry.py
    - pytest -v -s models/test_initialization.py

- label: Decoder-only Language Models Test (Standard) # 18min
  #mirror_hardwares: [amd]
@@ -166,14 +166,14 @@ TEXT_GENERATION_MODELS = {
    "mistralai/Mixtral-8x7B-Instruct-v0.1": PPTestSettings.fast(tp_base=4),
    "mosaicml/mpt-7b": PPTestSettings.fast(),
    "nvidia/Minitron-8B-Base": PPTestSettings.fast(),
    "allenai/OLMoE-1B-7B-0924-Instruct": PPTestSettings.fast(),
    "allenai/OLMo-1B-hf": PPTestSettings.fast(),
    "allenai/OLMoE-1B-7B-0924-Instruct": PPTestSettings.fast(),
    "facebook/opt-iml-max-1.3b": PPTestSettings.fast(),
    "OrionStarAI/Orion-14B-Chat": PPTestSettings.fast(trust_remote_code=True),
    "microsoft/phi-2": PPTestSettings.fast(),
    "microsoft/Phi-3.5-MoE-instruct": PPTestSettings.detailed(trust_remote_code=True, multi_node_only=True, load_format="dummy", hf_overrides='{"num_hidden_layers": 4, "hidden_size": 512, "intermediate_size": 800, "num_attention_heads": 4, "num_key_value_heads": 1}'), # noqa: E501
    "microsoft/Phi-3-small-8k-instruct": PPTestSettings.fast(trust_remote_code=True), # noqa: E501
    "adept/persimmon-8b-chat": PPTestSettings.fast(),
    "microsoft/phi-2": PPTestSettings.fast(),
    "microsoft/Phi-3-small-8k-instruct": PPTestSettings.fast(trust_remote_code=True), # noqa: E501
    "microsoft/Phi-3.5-MoE-instruct": PPTestSettings.detailed(trust_remote_code=True, multi_node_only=True, load_format="dummy", hf_overrides='{"num_hidden_layers": 4, "hidden_size": 512, "intermediate_size": 800, "num_attention_heads": 4, "num_key_value_heads": 1}'), # noqa: E501
    "Qwen/Qwen-7B-Chat": PPTestSettings.fast(trust_remote_code=True),
    "Qwen/Qwen2-7B-Instruct": PPTestSettings.fast(),
    "Qwen/Qwen1.5-MoE-A2.7B-Chat": PPTestSettings.fast(),
tests/models/registry.py (new file, 212 lines)
@@ -0,0 +1,212 @@
from dataclasses import dataclass, field
from typing import AbstractSet, Mapping, Optional


@dataclass(frozen=True)
class _HfExamplesInfo:
    default: str
    """The default model to use for testing this architecture."""

    extras: Mapping[str, str] = field(default_factory=dict)
    """Extra models to use for testing this architecture."""

    tokenizer: Optional[str] = None
    """Set the tokenizer to load for this architecture."""

    tokenizer_mode: str = "auto"
    """Set the tokenizer type for this architecture."""

    speculative_model: Optional[str] = None
    """
    The default model to use for testing this architecture, which is only used
    for speculative decoding.
    """

    is_available_online: bool = True
    """
    Set this to ``False`` if the name of this architecture no longer exists on
    the HF repo. To maintain backwards compatibility, we have not removed them
    from the main model registry, so without this flag the registry tests will
    fail.
    """

    trust_remote_code: bool = False
    """The ``trust_remote_code`` level required to load the model."""


# yapf: disable
_TEXT_GENERATION_EXAMPLE_MODELS = {
    # [Decoder-only]
    "AquilaModel": _HfExamplesInfo("BAAI/AquilaChat-7B",
                                   trust_remote_code=True),
    "AquilaForCausalLM": _HfExamplesInfo("BAAI/AquilaChat2-7B",
                                         trust_remote_code=True),
    "ArcticForCausalLM": _HfExamplesInfo("Snowflake/snowflake-arctic-instruct",
                                         trust_remote_code=True),
    "BaiChuanForCausalLM": _HfExamplesInfo("baichuan-inc/Baichuan-7B",
                                           trust_remote_code=True),
    "BaichuanForCausalLM": _HfExamplesInfo("baichuan-inc/Baichuan2-7B-chat",
                                           trust_remote_code=True),
    "BloomForCausalLM": _HfExamplesInfo("bigscience/bloomz-1b1"),
    # ChatGLMModel supports multimodal
    "CohereForCausalLM": _HfExamplesInfo("CohereForAI/c4ai-command-r-v01",
                                         trust_remote_code=True),
    "DbrxForCausalLM": _HfExamplesInfo("databricks/dbrx-instruct"),
    "DeciLMForCausalLM": _HfExamplesInfo("Deci/DeciLM-7B-instruct",
                                         trust_remote_code=True),
    "DeepseekForCausalLM": _HfExamplesInfo("deepseek-ai/deepseek-llm-7b-chat"),
    "DeepseekV2ForCausalLM": _HfExamplesInfo("deepseek-ai/DeepSeek-V2-Lite-Chat", # noqa: E501
                                             trust_remote_code=True),
    "ExaoneForCausalLM": _HfExamplesInfo("LGAI-EXAONE/EXAONE-3.0-7.8B-Instruct"), # noqa: E501
    "FalconForCausalLM": _HfExamplesInfo("tiiuae/falcon-7b"),
    "GemmaForCausalLM": _HfExamplesInfo("google/gemma-2b"),
    "Gemma2ForCausalLM": _HfExamplesInfo("google/gemma-2-9b"),
    "GPT2LMHeadModel": _HfExamplesInfo("gpt2"),
    "GPTBigCodeForCausalLM": _HfExamplesInfo("bigcode/starcoder"),
    "GPTJForCausalLM": _HfExamplesInfo("EleutherAI/gpt-j-6b"),
    "GPTNeoXForCausalLM": _HfExamplesInfo("EleutherAI/pythia-160m"),
    "GraniteForCausalLM": _HfExamplesInfo("ibm/PowerLM-3b"),
    "GraniteMoeForCausalLM": _HfExamplesInfo("ibm/PowerMoE-3b"),
    "InternLMForCausalLM": _HfExamplesInfo("internlm/internlm-chat-7b",
                                           trust_remote_code=True),
    "InternLM2ForCausalLM": _HfExamplesInfo("internlm/internlm2-chat-7b",
                                            trust_remote_code=True),
    "InternLM2VEForCausalLM": _HfExamplesInfo("OpenGVLab/Mono-InternVL-2B",
                                              trust_remote_code=True),
    "JAISLMHeadModel": _HfExamplesInfo("inceptionai/jais-13b-chat"),
    "JambaForCausalLM": _HfExamplesInfo("ai21labs/AI21-Jamba-1.5-Mini"),
    "LlamaForCausalLM": _HfExamplesInfo("meta-llama/Meta-Llama-3-8B"),
    "LLaMAForCausalLM": _HfExamplesInfo("decapoda-research/llama-7b-hf",
                                        is_available_online=False),
    "MambaForCausalLM": _HfExamplesInfo("state-spaces/mamba-130m-hf"),
    "FalconMambaForCausalLM": _HfExamplesInfo("tiiuae/falcon-mamba-7b-instruct"), # noqa: E501
    "MiniCPMForCausalLM": _HfExamplesInfo("openbmb/MiniCPM-2B-sft-bf16",
                                          trust_remote_code=True),
    "MiniCPM3ForCausalLM": _HfExamplesInfo("openbmb/MiniCPM3-4B",
                                           trust_remote_code=True),
    "MistralForCausalLM": _HfExamplesInfo("mistralai/Mistral-7B-Instruct-v0.1"),
    "MixtralForCausalLM": _HfExamplesInfo("mistralai/Mixtral-8x7B-Instruct-v0.1"), # noqa: E501
    "QuantMixtralForCausalLM": _HfExamplesInfo("mistral-community/Mixtral-8x22B-v0.1-AWQ"), # noqa: E501
    "MptForCausalLM": _HfExamplesInfo("mpt", is_available_online=False),
    "MPTForCausalLM": _HfExamplesInfo("mosaicml/mpt-7b"),
    "NemotronForCausalLM": _HfExamplesInfo("nvidia/Minitron-8B-Base"),
    "OlmoForCausalLM": _HfExamplesInfo("allenai/OLMo-1B-hf"),
    "OlmoeForCausalLM": _HfExamplesInfo("allenai/OLMoE-1B-7B-0924-Instruct"),
    "OPTForCausalLM": _HfExamplesInfo("facebook/opt-iml-max-1.3b"),
    "OrionForCausalLM": _HfExamplesInfo("OrionStarAI/Orion-14B-Chat",
                                        trust_remote_code=True),
    "PersimmonForCausalLM": _HfExamplesInfo("adept/persimmon-8b-chat"),
    "PhiForCausalLM": _HfExamplesInfo("microsoft/phi-2"),
    "Phi3ForCausalLM": _HfExamplesInfo("microsoft/Phi-3-mini-4k-instruct"),
    "Phi3SmallForCausalLM": _HfExamplesInfo("microsoft/Phi-3-small-8k-instruct",
                                            trust_remote_code=True),
    "PhiMoEForCausalLM": _HfExamplesInfo("microsoft/Phi-3.5-MoE-instruct",
                                         trust_remote_code=True),
    # QWenLMHeadModel supports multimodal
    "Qwen2ForCausalLM": _HfExamplesInfo("Qwen/Qwen2-7B-Instruct"),
    "Qwen2MoeForCausalLM": _HfExamplesInfo("Qwen/Qwen1.5-MoE-A2.7B-Chat"),
    "RWForCausalLM": _HfExamplesInfo("tiiuae/falcon-40b",
                                     is_available_online=False),
    "StableLMEpochForCausalLM": _HfExamplesInfo("stabilityai/stablelm-zephyr-3b", # noqa: E501
                                                is_available_online=False),
    "StableLmForCausalLM": _HfExamplesInfo("stabilityai/stablelm-3b-4e1t"),
    "Starcoder2ForCausalLM": _HfExamplesInfo("bigcode/starcoder2-3b"),
    "SolarForCausalLM": _HfExamplesInfo("upstage/solar-pro-preview-instruct"),
    "XverseForCausalLM": _HfExamplesInfo("xverse/XVERSE-7B-Chat",
                                         is_available_online=False,
                                         trust_remote_code=True),
    # [Encoder-decoder]
    "BartModel": _HfExamplesInfo("facebook/bart-base"),
    "BartForConditionalGeneration": _HfExamplesInfo("facebook/bart-large-cnn"),
    # Florence-2 uses BartFastTokenizer which can't be loaded from AutoTokenizer
    # Therefore, we borrow the BartTokenizer from the original Bart model
    "Florence2ForConditionalGeneration": _HfExamplesInfo("microsoft/Florence-2-base", # noqa: E501
                                                         tokenizer="facebook/bart-base",
                                                         trust_remote_code=True), # noqa: E501
}

_EMBEDDING_EXAMPLE_MODELS = {
    # [Text-only]
    "BertModel": _HfExamplesInfo("BAAI/bge-base-en-v1.5"),
    "Gemma2Model": _HfExamplesInfo("BAAI/bge-multilingual-gemma2"),
    "MistralModel": _HfExamplesInfo("intfloat/e5-mistral-7b-instruct"),
    "Qwen2ForRewardModel": _HfExamplesInfo("Qwen/Qwen2.5-Math-RM-72B"),
    "Qwen2ForSequenceClassification": _HfExamplesInfo("jason9693/Qwen2.5-1.5B-apeach"), # noqa: E501
    # [Multimodal]
    "LlavaNextForConditionalGeneration": _HfExamplesInfo("royokong/e5-v"),
    "Phi3VForCausalLM": _HfExamplesInfo("TIGER-Lab/VLM2Vec-Full",
                                        trust_remote_code=True),
    "Qwen2VLForConditionalGeneration": _HfExamplesInfo("MrLight/dse-qwen2-2b-mrl-v1"), # noqa: E501
}

_MULTIMODAL_EXAMPLE_MODELS = {
    # [Decoder-only]
    "Blip2ForConditionalGeneration": _HfExamplesInfo("Salesforce/blip2-opt-2.7b"), # noqa: E501
    "ChameleonForConditionalGeneration": _HfExamplesInfo("facebook/chameleon-7b"), # noqa: E501
    "ChatGLMModel": _HfExamplesInfo("THUDM/glm-4v-9b",
                                    extras={"text_only": "THUDM/chatglm3-6b"},
                                    trust_remote_code=True),
    "ChatGLMForConditionalGeneration": _HfExamplesInfo("chatglm2-6b",
                                                       is_available_online=False),
    "FuyuForCausalLM": _HfExamplesInfo("adept/fuyu-8b"),
    "H2OVLChatModel": _HfExamplesInfo("h2oai/h2ovl-mississippi-800m"),
    "InternVLChatModel": _HfExamplesInfo("OpenGVLab/InternVL2-1B",
                                         trust_remote_code=True),
    "Idefics3ForConditionalGeneration": _HfExamplesInfo("HuggingFaceM4/Idefics3-8B-Llama3"), # noqa: E501
    "LlavaForConditionalGeneration": _HfExamplesInfo("llava-hf/llava-1.5-7b-hf",
                                                     extras={"mistral": "mistral-community/pixtral-12b"}), # noqa: E501
    "LlavaNextForConditionalGeneration": _HfExamplesInfo("llava-hf/llava-v1.6-mistral-7b-hf"), # noqa: E501
    "LlavaNextVideoForConditionalGeneration": _HfExamplesInfo("llava-hf/LLaVA-NeXT-Video-7B-hf"), # noqa: E501
    "LlavaOnevisionForConditionalGeneration": _HfExamplesInfo("llava-hf/llava-onevision-qwen2-0.5b-ov-hf"), # noqa: E501
    "MiniCPMV": _HfExamplesInfo("openbmb/MiniCPM-Llama3-V-2_5",
                                trust_remote_code=True),
    "MolmoForCausalLM": _HfExamplesInfo("allenai/Molmo-7B-D-0924",
                                        trust_remote_code=True),
    "NVLM_D": _HfExamplesInfo("nvidia/NVLM-D-72B",
                              trust_remote_code=True),
    "PaliGemmaForConditionalGeneration": _HfExamplesInfo("google/paligemma-3b-pt-224"), # noqa: E501
    "Phi3VForCausalLM": _HfExamplesInfo("microsoft/Phi-3-vision-128k-instruct",
                                        trust_remote_code=True),
    "PixtralForConditionalGeneration": _HfExamplesInfo("mistralai/Pixtral-12B-2409", # noqa: E501
                                                       tokenizer_mode="mistral"),
    "QWenLMHeadModel": _HfExamplesInfo("Qwen/Qwen-VL-Chat",
                                       extras={"text_only": "Qwen/Qwen-7B-Chat"}, # noqa: E501
                                       trust_remote_code=True),
    "Qwen2AudioForConditionalGeneration": _HfExamplesInfo("Qwen/Qwen2-Audio-7B-Instruct"), # noqa: E501
    "Qwen2VLForConditionalGeneration": _HfExamplesInfo("Qwen/Qwen2-VL-2B-Instruct"), # noqa: E501
    "UltravoxModel": _HfExamplesInfo("fixie-ai/ultravox-v0_3"),
    # [Encoder-decoder]
    "MllamaForConditionalGeneration": _HfExamplesInfo("meta-llama/Llama-3.2-11B-Vision-Instruct"), # noqa: E501
}

_SPECULATIVE_DECODING_EXAMPLE_MODELS = {
    "EAGLEModel": _HfExamplesInfo("JackFram/llama-68m",
                                  speculative_model="abhigoyal/vllm-eagle-llama-68m-random"), # noqa: E501
    "MedusaModel": _HfExamplesInfo("JackFram/llama-68m",
                                   speculative_model="abhigoyal/vllm-medusa-llama-68m-random"), # noqa: E501
    "MLPSpeculatorPreTrainedModel": _HfExamplesInfo("JackFram/llama-160m",
                                                    speculative_model="ibm-fms/llama-160m-accelerator"), # noqa: E501
}

_EXAMPLE_MODELS = {
    **_TEXT_GENERATION_EXAMPLE_MODELS,
    **_EMBEDDING_EXAMPLE_MODELS,
    **_MULTIMODAL_EXAMPLE_MODELS,
    **_SPECULATIVE_DECODING_EXAMPLE_MODELS,
}


class HfExampleModels:
    def __init__(self, hf_models: Mapping[str, _HfExamplesInfo]) -> None:
        super().__init__()

        self.hf_models = hf_models

    def get_supported_archs(self) -> AbstractSet[str]:
        return self.hf_models.keys()

    def get_hf_info(self, model_arch: str) -> _HfExamplesInfo:
        return self.hf_models[model_arch]


HF_EXAMPLE_MODELS = HfExampleModels(_EXAMPLE_MODELS)
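Note (not part of the diff): the new registry is just a mapping from architecture name to example checkpoint metadata, so tests can query it directly. A minimal sketch, assuming the repository root is on sys.path so the test package is importable:

# Hypothetical usage sketch of the registry above (illustrative only).
from tests.models.registry import HF_EXAMPLE_MODELS

info = HF_EXAMPLE_MODELS.get_hf_info("OPTForCausalLM")
print(info.default)             # "facebook/opt-iml-max-1.3b"
print(info.trust_remote_code)   # False
print(sorted(HF_EXAMPLE_MODELS.get_supported_archs())[:3])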
tests/models/test_initialization.py (new file, 55 lines)
@@ -0,0 +1,55 @@
from unittest.mock import patch

import pytest
import transformers
from transformers import PretrainedConfig

from vllm import LLM

from .registry import HF_EXAMPLE_MODELS


@pytest.mark.parametrize("model_arch", HF_EXAMPLE_MODELS.get_supported_archs())
def test_can_initialize(model_arch):
    if (model_arch == "Idefics3ForConditionalGeneration"
            and transformers.__version__ < "4.46.0"):
        pytest.skip(reason="Model introduced in HF >= 4.46.0")

    model_info = HF_EXAMPLE_MODELS.get_hf_info(model_arch)
    if not model_info.is_available_online:
        pytest.skip("Model is not available online")

    # Avoid OOM
    def hf_overrides(hf_config: PretrainedConfig) -> PretrainedConfig:
        if hasattr(hf_config, "text_config"):
            text_config: PretrainedConfig = hf_config.text_config
        else:
            text_config = hf_config

        text_config.update({
            "num_layers": 1,
            "num_hidden_layers": 1,
            "num_experts": 2,
            "num_experts_per_tok": 2,
            "num_local_experts": 2,
        })

        return hf_config

    # Avoid calling model.forward()
    def _initialize_kv_caches(self) -> None:
        self.cache_config.num_gpu_blocks = 0
        self.cache_config.num_cpu_blocks = 0

    with patch.object(LLM.get_engine_class(), "_initialize_kv_caches",
                      _initialize_kv_caches):
        LLM(
            model_info.default,
            tokenizer=model_info.tokenizer,
            tokenizer_mode=model_info.tokenizer_mode,
            speculative_model=model_info.speculative_model,
            num_speculative_tokens=1 if model_info.speculative_model else None,
            trust_remote_code=model_info.trust_remote_code,
            load_format="dummy",
            hf_overrides=hf_overrides,
        )
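The same trick can be used outside pytest to smoke-test a single checkpoint. A hedged standalone sketch; the model name is an arbitrary small example, not taken from the diff, and it relies only on the `load_format="dummy"` and KV-cache patch shown above:

# Standalone sketch of the initialization check above (illustrative only).
from unittest.mock import patch
from vllm import LLM

def _skip_kv_caches(self) -> None:
    # Same idea as test_can_initialize: allocate no KV cache blocks.
    self.cache_config.num_gpu_blocks = 0
    self.cache_config.num_cpu_blocks = 0

with patch.object(LLM.get_engine_class(), "_initialize_kv_caches",
                  _skip_kv_caches):
    # Dummy weights plus a 1-layer override keep this cheap;
    # "facebook/opt-125m" is just a small checkpoint chosen for illustration.
    LLM("facebook/opt-125m",
        load_format="dummy",
        hf_overrides={"num_hidden_layers": 1})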
@@ -14,6 +14,7 @@ from vllm.model_executor.models.registry import (_EMBEDDING_MODELS,
from vllm.platforms import current_platform

from ..utils import fork_new_process_for_each_test
from .registry import HF_EXAMPLE_MODELS


@pytest.mark.parametrize("model_arch", ModelRegistry.get_supported_archs())
@@ -73,3 +74,12 @@ def test_registry_is_pp(model_arch, is_pp, init_cuda):
        "This model no longer initializes CUDA on import. "
        "Please test using a different one.",
        stacklevel=2)


def test_hf_registry_coverage():
    untested_archs = (HF_EXAMPLE_MODELS.get_supported_archs() -
                      set(ModelRegistry.get_supported_archs()))

    assert not untested_archs, (
        "Please add the following architectures to "
        f"`tests/models/registry.py`: {untested_archs}")
@@ -3,8 +3,8 @@ import enum
import json
import warnings
from dataclasses import dataclass, field, replace
from typing import (TYPE_CHECKING, Any, ClassVar, Dict, Final, List, Literal,
                    Mapping, Optional, Set, Tuple, Type, Union)
from typing import (TYPE_CHECKING, Any, Callable, ClassVar, Dict, Final, List,
                    Literal, Mapping, Optional, Set, Tuple, Type, Union)

import torch
from transformers import PretrainedConfig
@@ -20,7 +20,7 @@ from vllm.transformers_utils.config import (
    get_hf_text_config, get_pooling_config,
    get_sentence_transformer_tokenizer_config, is_encoder_decoder, uses_mrope)
from vllm.utils import (GiB_bytes, cuda_device_count_stateless, get_cpu_memory,
                        print_warning_once)
                        identity, print_warning_once)

if TYPE_CHECKING:
    from ray.util.placement_group import PlacementGroup
@@ -44,6 +44,9 @@ TaskOption = Literal["auto", "generate", "embedding"]
# "draft" is only used internally for speculative decoding
_Task = Literal["generate", "embedding", "draft"]

HfOverrides = Union[Dict[str, Any], Callable[[PretrainedConfig],
                                             PretrainedConfig]]


class ModelConfig:
    """Configuration for the model.
@@ -115,7 +118,9 @@ class ModelConfig:
            can not be gathered from the vllm arguments.
        config_format: The config format which shall be loaded.
            Defaults to 'auto' which defaults to 'hf'.
        hf_overrides: Arguments to be forwarded to the HuggingFace config.
        hf_overrides: If a dictionary, contains arguments to be forwarded to the
            HuggingFace config. If a callable, it is called to update the
            HuggingFace config.
        mm_processor_kwargs: Arguments to be forwarded to the model's processor
            for multi-modal data, e.g., image processor.
        pooling_type: Used to configure the pooling method in the embedding
@@ -164,7 +169,7 @@ class ModelConfig:
        override_neuron_config: Optional[Dict[str, Any]] = None,
        config_format: ConfigFormat = ConfigFormat.AUTO,
        chat_template_text_format: str = "string",
        hf_overrides: Optional[Dict[str, Any]] = None,
        hf_overrides: Optional[HfOverrides] = None,
        mm_processor_kwargs: Optional[Dict[str, Any]] = None,
        pooling_type: Optional[str] = None,
        pooling_norm: Optional[bool] = None,
@@ -182,15 +187,23 @@ class ModelConfig:

        if hf_overrides is None:
            hf_overrides = {}

        if callable(hf_overrides):
            hf_overrides_kw = {}
            hf_overrides_fn = hf_overrides
        else:
            hf_overrides_kw = hf_overrides
            hf_overrides_fn = identity

        if rope_scaling is not None:
            hf_override: Dict[str, Any] = {"rope_scaling": rope_scaling}
            hf_overrides.update(hf_override)
            hf_overrides_kw.update(hf_override)
            msg = ("`--rope-scaling` will be removed in a future release. "
                   f"'Please instead use `--hf-overrides '{hf_override!r}'`")
            warnings.warn(DeprecationWarning(msg), stacklevel=2)
        if rope_theta is not None:
            hf_override = {"rope_theta": rope_theta}
            hf_overrides.update(hf_override)
            hf_overrides_kw.update(hf_override)
            msg = ("`--rope-theta` will be removed in a future release. "
                   f"'Please instead use `--hf-overrides '{hf_override!r}'`")
            warnings.warn(DeprecationWarning(msg), stacklevel=2)
@@ -207,9 +220,12 @@ class ModelConfig:
        self.max_logprobs = max_logprobs
        self.disable_sliding_window = disable_sliding_window
        self.skip_tokenizer_init = skip_tokenizer_init
        self.hf_config = get_config(self.model, trust_remote_code, revision,
                                    code_revision, config_format,
                                    **hf_overrides)

        hf_config = get_config(self.model, trust_remote_code, revision,
                               code_revision, config_format, **hf_overrides_kw)
        hf_config = hf_overrides_fn(hf_config)
        self.hf_config = hf_config

        self.hf_text_config = get_hf_text_config(self.hf_config)
        self.encoder_config = self._get_encoder_config()
        self.hf_image_processor_config = get_hf_image_processor_config(
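For readers skimming the hunk above: dictionary overrides are merged into `get_config(...)` as keyword arguments, while a callable is applied to the already-loaded config object. A self-contained sketch of that dispatch, with names chosen for illustration rather than taken from vLLM:

# Illustrative sketch of the dict-vs-callable dispatch above (not vLLM code).
from typing import Any, Callable, Dict, Tuple, Union

Overrides = Union[Dict[str, Any], Callable[[dict], dict]]

def split_overrides(overrides: Overrides) -> Tuple[Dict[str, Any], Callable[[dict], dict]]:
    """Return (keyword overrides, post-load hook), mirroring the
    hf_overrides_kw / hf_overrides_fn split in the diff above."""
    if callable(overrides):
        return {}, overrides
    return dict(overrides), (lambda cfg: cfg)  # identity hook for dict input

kw, fn = split_overrides({"rope_theta": 1e6})
assert kw == {"rope_theta": 1e6}
kw, fn = split_overrides(lambda cfg: cfg)
assert kw == {}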
@@ -9,9 +9,9 @@ import torch

import vllm.envs as envs
from vllm.config import (CacheConfig, ConfigFormat, DecodingConfig,
                         DeviceConfig, LoadConfig, LoadFormat, LoRAConfig,
                         ModelConfig, ObservabilityConfig, ParallelConfig,
                         PromptAdapterConfig, SchedulerConfig,
                         DeviceConfig, HfOverrides, LoadConfig, LoadFormat,
                         LoRAConfig, ModelConfig, ObservabilityConfig,
                         ParallelConfig, PromptAdapterConfig, SchedulerConfig,
                         SpeculativeConfig, TaskOption, TokenizerPoolConfig,
                         VllmConfig)
from vllm.executor.executor_base import ExecutorBase
@@ -128,7 +128,7 @@ class EngineArgs:
    code_revision: Optional[str] = None
    rope_scaling: Optional[Dict[str, Any]] = None
    rope_theta: Optional[float] = None
    hf_overrides: Optional[Dict[str, Any]] = None
    hf_overrides: Optional[HfOverrides] = None
    tokenizer_revision: Optional[str] = None
    quantization: Optional[str] = None
    enforce_eager: Optional[bool] = None
@@ -9,7 +9,7 @@ from tqdm import tqdm
from vllm import envs
from vllm.beam_search import (BeamSearchInstance, BeamSearchOutput,
                              BeamSearchSequence, get_beam_search_score)
from vllm.engine.arg_utils import EngineArgs, TaskOption
from vllm.engine.arg_utils import EngineArgs, HfOverrides, TaskOption
from vllm.engine.llm_engine import LLMEngine
from vllm.entrypoints.chat_utils import (ChatCompletionMessageParam,
                                         apply_hf_chat_template,
@@ -101,7 +101,9 @@ class LLM:
        disable_custom_all_reduce: See :class:`~vllm.config.ParallelConfig`
        disable_async_output_proc: Disable async output processing.
            This may result in lower performance.
        hf_overrides: Arguments to be forwarded to the HuggingFace config.
        hf_overrides: If a dictionary, contains arguments to be forwarded to the
            HuggingFace config. If a callable, it is called to update the
            HuggingFace config.
        **kwargs: Arguments for :class:`~vllm.EngineArgs`. (See
            :ref:`engine_args`)

@@ -156,7 +158,7 @@ class LLM:
        max_seq_len_to_capture: int = 8192,
        disable_custom_all_reduce: bool = False,
        disable_async_output_proc: bool = False,
        hf_overrides: Optional[dict] = None,
        hf_overrides: Optional[HfOverrides] = None,
        mm_processor_kwargs: Optional[Dict[str, Any]] = None,
        # After positional args are removed, move this right below `model`
        task: TaskOption = "auto",
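End-user view of the change above: `LLM(..., hf_overrides=...)` now accepts either a dict or a callable. A hedged usage sketch; the model name is an arbitrary example, not taken from the diff:

# Usage sketch for the new HfOverrides union (model name is illustrative).
from transformers import PretrainedConfig
from vllm import LLM

def shrink(cfg: PretrainedConfig) -> PretrainedConfig:
    # Callable form: applied to the already-loaded config.
    cfg.update({"num_hidden_layers": 2})
    return cfg

llm = LLM("facebook/opt-125m", hf_overrides=shrink, load_format="dummy")
# Dict form (equivalent for flat keys):
#   LLM("facebook/opt-125m", hf_overrides={"num_hidden_layers": 2},
#       load_format="dummy")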
@@ -41,7 +41,8 @@ from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, IntermediateTensors,
from vllm.utils import is_list_of

from .interfaces import SupportsMultiModal, SupportsPP
from .utils import AutoWeightsLoader, flatten_bn, merge_multimodal_embeddings
from .utils import (AutoWeightsLoader, flatten_bn, maybe_prefix,
                    merge_multimodal_embeddings)

# Cannot find the following 2 numbers from hf config.
_IMAGE_TOKEN_ID = 71011
@@ -245,7 +246,9 @@ class FuyuForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
            gather_output=True,
        )
        self.language_model = PersimmonForCausalLM(
            vllm_config.with_hf_config(config.text_config))
            vllm_config=vllm_config.with_hf_config(config.text_config),
            prefix=maybe_prefix(prefix, "language_model"),
        )
        self.make_empty_intermediate_tensors = (
            self.language_model.make_empty_intermediate_tensors)
@@ -161,11 +161,5 @@ class InternLM2VEForCausalLM(InternLM2ForCausalLM):
    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__(vllm_config=vllm_config, prefix=prefix)

        config = vllm_config.model_config.hf_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config

        self.model = InternLM2VEModel(config,
                                      cache_config,
                                      quant_config,
        self.model = InternLM2VEModel(vllm_config=vllm_config,
                                      prefix=maybe_prefix(prefix, "model"))
@@ -382,11 +382,7 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP):
        instantiated.
    """

    def __init__(
        self,
        vllm_config: VllmConfig,
        prefix: str = "",
    ):
    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        config = vllm_config.model_config.hf_config
        multimodal_config = vllm_config.model_config.multimodal_config
        quant_config = vllm_config.quant_config
@@ -699,12 +695,8 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP):

class MiniCPMV2_0(MiniCPMVBaseModel):

    def __init__(
        self,
        vllm_config: VllmConfig,
        prefix: str = "",
    ):
        super().__init__(vllm_config)
    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__(vllm_config=vllm_config, prefix=prefix)
        assert self.version == (2, 0)

    def init_llm(
@@ -857,12 +849,8 @@ class MiniCPMV2_5(MiniCPMVBaseModel, SupportsLoRA):
    embedding_modules = {}
    embedding_padding_modules = []

    def __init__(
        self,
        vllm_config: VllmConfig,
        prefix: str = "",
    ):
        super().__init__(vllm_config)
    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__(vllm_config=vllm_config, prefix=prefix)
        assert self.version == (2, 5)

    def init_llm(
@@ -999,12 +987,8 @@ class MiniCPMV2_6(MiniCPMVBaseModel, SupportsLoRA):
    embedding_modules = {}
    embedding_padding_modules = []

    def __init__(
        self,
        vllm_config: VllmConfig,
        prefix: str = "",
    ):
        super().__init__(vllm_config)
    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__(vllm_config=vllm_config, prefix=prefix)
        assert self.version == (2, 6)

    def init_llm(
@@ -1117,7 +1101,7 @@ class MiniCPMV(MiniCPMVBaseModel, SupportsLoRA):
    embedding_modules = {}
    embedding_padding_modules = []

    def __new__(cls, vllm_config: VllmConfig, prefix: str = ""):
    def __new__(cls, *, vllm_config: VllmConfig, prefix: str = ""):
        config = vllm_config.model_config.hf_config
        if not hasattr(config, "version"):
            if config.hidden_size == 2304 and config.query_num == 64:
@@ -65,7 +65,7 @@ class MLPSpeculator(nn.Module):
    https://huggingface.co/ibm-fms and https://huggingface.co/ibm-granite
    """

    def __init__(self, vllm_config: VllmConfig, prefix: str = "") -> None:
    def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
        super().__init__()
        config = vllm_config.model_config.hf_config
        self.n_predict = config.n_predict
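Several hunks above (InternLM2VE, MiniCPMV, MLPSpeculator) add a bare `*` to the constructor, making `vllm_config` keyword-only. A minimal illustration of what that changes; the class names here are made up, not vLLM's:

# Illustration of the keyword-only constructor pattern (not vLLM classes).
class Legacy:
    def __init__(self, vllm_config, prefix=""):
        self.vllm_config = vllm_config

class KeywordOnly:
    def __init__(self, *, vllm_config, prefix=""):
        self.vllm_config = vllm_config

Legacy("cfg")                   # accepted positionally
KeywordOnly(vllm_config="cfg")  # required calling convention
# KeywordOnly("cfg")  -> TypeError: takes 1 positional argument but 2 were given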
@@ -1,3 +1,7 @@
"""
Whenever you add an architecture to this page, please also update
`tests/models/registry.py` with example HuggingFace models for it.
"""
import importlib
import os
import pickle
@@ -58,14 +62,14 @@ _TEXT_GENERATION_MODELS = {
    "LLaMAForCausalLM": ("llama", "LlamaForCausalLM"),
    "MambaForCausalLM": ("mamba", "MambaForCausalLM"),
    "FalconMambaForCausalLM": ("mamba", "MambaForCausalLM"),
    "MiniCPMForCausalLM": ("minicpm", "MiniCPMForCausalLM"),
    "MiniCPM3ForCausalLM": ("minicpm3", "MiniCPM3ForCausalLM"),
    "MistralForCausalLM": ("llama", "LlamaForCausalLM"),
    "MixtralForCausalLM": ("mixtral", "MixtralForCausalLM"),
    "QuantMixtralForCausalLM": ("mixtral_quant", "MixtralForCausalLM"),
    # transformers's mpt class has lower case
    "MptForCausalLM": ("mpt", "MPTForCausalLM"),
    "MPTForCausalLM": ("mpt", "MPTForCausalLM"),
    "MiniCPMForCausalLM": ("minicpm", "MiniCPMForCausalLM"),
    "MiniCPM3ForCausalLM": ("minicpm3", "MiniCPM3ForCausalLM"),
    "NemotronForCausalLM": ("nemotron", "NemotronForCausalLM"),
    "OlmoForCausalLM": ("olmo", "OlmoForCausalLM"),
    "OlmoeForCausalLM": ("olmoe", "OlmoeForCausalLM"),