From 6c85da3a1859cbd4bc3cd76fc7210a33af077264 Mon Sep 17 00:00:00 2001
From: Roger Wang <136131678+ywang96@users.noreply.github.com>
Date: Thu, 27 Feb 2025 17:02:15 -0800
Subject: [PATCH] [V1]`SupportsV0Only` protocol for model definitions (#13959)

Signed-off-by: Roger Wang
---
 vllm/config.py                            |  5 ++++
 vllm/model_executor/models/__init__.py    |  7 +++--
 vllm/model_executor/models/bamba.py       |  5 ++--
 vllm/model_executor/models/bart.py        |  3 ++-
 vllm/model_executor/models/bert.py        |  4 +--
 vllm/model_executor/models/florence2.py   |  4 +--
 vllm/model_executor/models/gritlm.py      |  4 ++-
 vllm/model_executor/models/interfaces.py  | 26 +++++++++++++++++++
 vllm/model_executor/models/jamba.py       |  5 ++--
 vllm/model_executor/models/mamba.py       |  6 +++--
 vllm/model_executor/models/mamba2.py      |  6 +++--
 vllm/model_executor/models/minicpmv.py    |  6 +++--
 vllm/model_executor/models/mllama.py      |  5 ++--
 vllm/model_executor/models/paligemma.py   |  4 +--
 .../models/prithvi_geospatial_mae.py      |  6 +++--
 vllm/model_executor/models/qwen2_rm.py    |  5 ++--
 vllm/model_executor/models/registry.py    | 14 ++++++++--
 vllm/model_executor/models/roberta.py     |  5 ++--
 vllm/model_executor/models/whisper.py     |  5 ++--
 19 files changed, 93 insertions(+), 32 deletions(-)

diff --git a/vllm/config.py b/vllm/config.py
index c3f9932ab8b3f..78d02b0173503 100644
--- a/vllm/config.py
+++ b/vllm/config.py
@@ -1039,6 +1039,11 @@ class ModelConfig:
     def runner_type(self) -> RunnerType:
         return _TASK_RUNNER[self.task]

+    @property
+    def is_v1_compatible(self) -> bool:
+        architectures = getattr(self.hf_config, "architectures", [])
+        return ModelRegistry.is_v1_compatible(architectures)
+

 class CacheConfig:
     """Configuration for the KV cache.
diff --git a/vllm/model_executor/models/__init__.py b/vllm/model_executor/models/__init__.py
index 6be4a8341306e..3580c4fa52525 100644
--- a/vllm/model_executor/models/__init__.py
+++ b/vllm/model_executor/models/__init__.py
@@ -1,8 +1,9 @@
 # SPDX-License-Identifier: Apache-2.0

 from .interfaces import (HasInnerState, SupportsLoRA, SupportsMultiModal,
-                         SupportsPP, has_inner_state, supports_lora,
-                         supports_multimodal, supports_pp)
+                         SupportsPP, SupportsV0Only, has_inner_state,
+                         supports_lora, supports_multimodal, supports_pp,
+                         supports_v0_only)
 from .interfaces_base import (VllmModelForPooling, VllmModelForTextGeneration,
                               is_pooling_model, is_text_generation_model)
 from .registry import ModelRegistry
@@ -21,4 +22,6 @@ __all__ = [
     "supports_multimodal",
     "SupportsPP",
     "supports_pp",
+    "SupportsV0Only",
+    "supports_v0_only",
 ]
diff --git a/vllm/model_executor/models/bamba.py b/vllm/model_executor/models/bamba.py
index 69da05884ded8..ec62e41d59f0f 100644
--- a/vllm/model_executor/models/bamba.py
+++ b/vllm/model_executor/models/bamba.py
@@ -32,7 +32,8 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.sequence import IntermediateTensors
 from vllm.utils import LayerBlockType

-from .interfaces import HasInnerState, IsHybrid, SupportsLoRA, SupportsPP
+from .interfaces import (HasInnerState, IsHybrid, SupportsLoRA, SupportsPP,
+                         SupportsV0Only)
 from .utils import (is_pp_missing_parameter,
                     make_empty_intermediate_tensors_factory, make_layers,
                     maybe_prefix)
@@ -366,7 +367,7 @@ class BambaModel(nn.Module):


 class BambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
-                       IsHybrid):
+                       IsHybrid, SupportsV0Only):
     packed_modules_mapping = {
         "qkv_proj": [
             "q_proj",
diff --git a/vllm/model_executor/models/bart.py b/vllm/model_executor/models/bart.py
index 93452696dca55..82684dfa730e4 100644
--- a/vllm/model_executor/models/bart.py
+++ b/vllm/model_executor/models/bart.py
@@ -43,6 +43,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.sequence import IntermediateTensors

+from .interfaces import SupportsV0Only
 from .utils import maybe_prefix

 logger = logging.get_logger(__name__)
@@ -776,7 +777,7 @@ class BartModel(nn.Module):
         return decoder_outputs


-class BartForConditionalGeneration(nn.Module):
+class BartForConditionalGeneration(nn.Module, SupportsV0Only):
     base_model_prefix = "model"

     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
diff --git a/vllm/model_executor/models/bert.py b/vllm/model_executor/models/bert.py
index 4ff69527653d8..77b2ef0fce5f4 100644
--- a/vllm/model_executor/models/bert.py
+++ b/vllm/model_executor/models/bert.py
@@ -26,7 +26,7 @@ from vllm.sequence import IntermediateTensors, PoolerOutput
 from vllm.transformers_utils.config import (
     get_cross_encoder_activation_function)

-from .interfaces import SupportsCrossEncoding
+from .interfaces import SupportsCrossEncoding, SupportsV0Only
 from .utils import WeightsMapper, maybe_prefix


@@ -385,7 +385,7 @@ class BertModel(nn.Module):
         return loaded_params


-class BertEmbeddingModel(nn.Module):
+class BertEmbeddingModel(nn.Module, SupportsV0Only):
     """A model that uses Bert to provide embedding functionalities.

     This class encapsulates the BertModel and provides an interface for
diff --git a/vllm/model_executor/models/florence2.py b/vllm/model_executor/models/florence2.py
index c51fcf3d438bc..6fa1bb80995d6 100644
--- a/vllm/model_executor/models/florence2.py
+++ b/vllm/model_executor/models/florence2.py
@@ -29,7 +29,7 @@ from vllm.multimodal.processing import (BaseProcessingInfo,
 from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
 from vllm.sequence import IntermediateTensors

-from .interfaces import SupportsMultiModal
+from .interfaces import SupportsMultiModal, SupportsV0Only
 from .utils import AutoWeightsLoader, flatten_bn, merge_multimodal_embeddings


@@ -651,7 +651,7 @@ class Florence2LanguageModel(nn.Module):
         return decoder_outputs


-class Florence2LanguageForConditionalGeneration(nn.Module):
+class Florence2LanguageForConditionalGeneration(nn.Module, SupportsV0Only):

     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         super().__init__()
diff --git a/vllm/model_executor/models/gritlm.py b/vllm/model_executor/models/gritlm.py
index 16223953ff839..2984f22412864 100644
--- a/vllm/model_executor/models/gritlm.py
+++ b/vllm/model_executor/models/gritlm.py
@@ -19,6 +19,8 @@ from vllm.sequence import (IntermediateTensors, PoolerOutput,
                            PoolingSequenceGroupOutput)
 from vllm.transformers_utils.tokenizer import cached_tokenizer_from_config

+from .interfaces import SupportsV0Only
+
 logger = init_logger(__name__)


@@ -177,7 +179,7 @@ class GritLMPooler(nn.Module):
         return PoolerOutput(outputs=pooled_outputs)


-class GritLM(LlamaForCausalLM):
+class GritLM(LlamaForCausalLM, SupportsV0Only):
     """This class implements the embedding model for parasail-ai/GritLM-7B-vllm.

     The class inherits from LlamaForCausalLM and provides a custom pooling
diff --git a/vllm/model_executor/models/interfaces.py b/vllm/model_executor/models/interfaces.py
index 47bd05f140c81..fb3ceb005295d 100644
--- a/vllm/model_executor/models/interfaces.py
+++ b/vllm/model_executor/models/interfaces.py
@@ -498,3 +498,29 @@ def supports_transcription(
         return isinstance(model, SupportsTranscription)

     return isinstance(model, SupportsTranscription)
+
+
+@runtime_checkable
+class SupportsV0Only(Protocol):
+    """Models with this interface are not compatible with V1 vLLM."""
+
+    supports_v0_only: ClassVar[Literal[True]] = True
+
+
+@overload
+def supports_v0_only(model: Type[object]) -> TypeIs[Type[SupportsV0Only]]:
+    ...
+
+
+@overload
+def supports_v0_only(model: object) -> TypeIs[SupportsV0Only]:
+    ...
+
+
+def supports_v0_only(
+    model: Union[Type[object], object],
+) -> Union[TypeIs[Type[SupportsV0Only]], TypeIs[SupportsV0Only]]:
+    if isinstance(model, type):
+        return isinstance(model, SupportsV0Only)
+
+    return isinstance(model, SupportsV0Only)
diff --git a/vllm/model_executor/models/jamba.py b/vllm/model_executor/models/jamba.py
index 14e56df6cadf8..58eccd6a6b87d 100644
--- a/vllm/model_executor/models/jamba.py
+++ b/vllm/model_executor/models/jamba.py
@@ -30,7 +30,8 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.sequence import IntermediateTensors, PoolerOutput
 from vllm.utils import LayerBlockType

-from .interfaces import HasInnerState, IsHybrid, SupportsLoRA, SupportsPP
+from .interfaces import (HasInnerState, IsHybrid, SupportsLoRA, SupportsPP,
+                         SupportsV0Only)
 from .utils import (is_pp_missing_parameter,
                     make_empty_intermediate_tensors_factory, make_layers,
                     maybe_prefix)
@@ -353,7 +354,7 @@ class JambaModel(nn.Module):


 class JambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
-                       IsHybrid):
+                       IsHybrid, SupportsV0Only):
     packed_modules_mapping = {
         "qkv_proj": [
             "q_proj",
diff --git a/vllm/model_executor/models/mamba.py b/vllm/model_executor/models/mamba.py
index 9f1cd8c29a5a0..46b9182f2d79b 100644
--- a/vllm/model_executor/models/mamba.py
+++ b/vllm/model_executor/models/mamba.py
@@ -19,7 +19,8 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
     DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.model_executor.models.interfaces import (HasInnerState,
-                                                   IsAttentionFree, SupportsPP)
+                                                   IsAttentionFree, SupportsPP,
+                                                   SupportsV0Only)
 from vllm.model_executor.models.mamba_cache import (MambaCacheManager,
                                                     MambaCacheParams)
 from vllm.model_executor.sampling_metadata import SamplingMetadata
@@ -155,7 +156,8 @@ class MambaModel(nn.Module):
         return hidden_states


-class MambaForCausalLM(nn.Module, HasInnerState, IsAttentionFree, SupportsPP):
+class MambaForCausalLM(nn.Module, HasInnerState, IsAttentionFree, SupportsPP,
+                       SupportsV0Only):

     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         config = vllm_config.model_config.hf_config
diff --git a/vllm/model_executor/models/mamba2.py b/vllm/model_executor/models/mamba2.py
index 266cdc243ac44..da5cbddbcbc58 100644
--- a/vllm/model_executor/models/mamba2.py
+++ b/vllm/model_executor/models/mamba2.py
@@ -22,7 +22,8 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
     DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.model_executor.models.interfaces import (HasInnerState,
-                                                   IsAttentionFree)
+                                                   IsAttentionFree,
+                                                   SupportsV0Only)
 from vllm.model_executor.models.mamba_cache import (MambaCacheManager,
                                                     MambaCacheParams)
 from vllm.model_executor.sampling_metadata import SamplingMetadata
@@ -174,7 +175,8 @@ class Mamba2Model(nn.Module):
         return hidden_states


-class Mamba2ForCausalLM(nn.Module, HasInnerState, IsAttentionFree):
+class Mamba2ForCausalLM(nn.Module, HasInnerState, IsAttentionFree,
+                        SupportsV0Only):

     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         config = vllm_config.model_config.hf_config
diff --git a/vllm/model_executor/models/minicpmv.py b/vllm/model_executor/models/minicpmv.py
index fb6ea53acf9e4..1816bf5d008d7 100644
--- a/vllm/model_executor/models/minicpmv.py
+++ b/vllm/model_executor/models/minicpmv.py
@@ -63,7 +63,8 @@ from vllm.platforms import current_platform
 from vllm.sequence import IntermediateTensors

 from .idefics2_vision_model import Idefics2VisionTransformer
-from .interfaces import SupportsLoRA, SupportsMultiModal, SupportsPP
+from .interfaces import (SupportsLoRA, SupportsMultiModal, SupportsPP,
+                         SupportsV0Only)
 from .utils import AutoWeightsLoader, maybe_prefix

 CPU_DEVICE = torch.device("cpu")
@@ -804,7 +805,8 @@ class MiniCPMVMultiModalProcessor(BaseMultiModalProcessor[_I]):
         return result


-class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP):
+class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP,
+                        SupportsV0Only):
     """
     The abstract class of MiniCPMV can only be inherited, but cannot be
     instantiated.
diff --git a/vllm/model_executor/models/mllama.py b/vllm/model_executor/models/mllama.py
index 36e653e41e1bf..7122fea2b3a80 100644
--- a/vllm/model_executor/models/mllama.py
+++ b/vllm/model_executor/models/mllama.py
@@ -63,7 +63,7 @@ from vllm.multimodal.processing import (BaseProcessingInfo,
 from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs

 from .clip import CLIPMLP
-from .interfaces import SupportsMultiModal
+from .interfaces import SupportsMultiModal, SupportsV0Only
 from .llama import LlamaDecoderLayer, LlamaMLP
 from .utils import maybe_prefix

@@ -1128,7 +1128,8 @@ class MllamaForCausalLM(nn.Module):
 @MULTIMODAL_REGISTRY.register_processor(MllamaMultiModalProcessor,
                                         info=MllamaProcessingInfo,
                                         dummy_inputs=MllamaDummyInputsBuilder)
-class MllamaForConditionalGeneration(nn.Module, SupportsMultiModal):
+class MllamaForConditionalGeneration(nn.Module, SupportsMultiModal,
+                                     SupportsV0Only):
     packed_modules_mapping = {
         "qkv_proj": ["q_proj", "k_proj", "v_proj"],
         "gate_up_proj": ["gate_proj", "up_proj"]
diff --git a/vllm/model_executor/models/paligemma.py b/vllm/model_executor/models/paligemma.py
index 02d1861b8027c..9a1398c28dbcb 100644
--- a/vllm/model_executor/models/paligemma.py
+++ b/vllm/model_executor/models/paligemma.py
@@ -18,7 +18,7 @@ from vllm.multimodal.inputs import NestedTensors
 from vllm.sequence import IntermediateTensors
 from vllm.transformers_utils.tokenizer import cached_tokenizer_from_config

-from .interfaces import SupportsMultiModal, SupportsPP
+from .interfaces import SupportsMultiModal, SupportsPP, SupportsV0Only
 from .siglip import (SiglipVisionModel, dummy_image_for_siglip,
                      dummy_seq_data_for_siglip, get_max_siglip_image_tokens)
 from .utils import (AutoWeightsLoader, init_vllm_registered_model,
@@ -136,7 +136,7 @@ class PaliGemmaMultiModalProjector(nn.Module):
 @INPUT_REGISTRY.register_dummy_data(dummy_data_for_paligemma)
 @INPUT_REGISTRY.register_input_processor(input_processor_for_paligemma)
 class PaliGemmaForConditionalGeneration(nn.Module, SupportsMultiModal,
-                                        SupportsPP):
+                                        SupportsPP, SupportsV0Only):
     packed_modules_mapping = {
         "qkv_proj": [
             "q_proj",
diff --git a/vllm/model_executor/models/prithvi_geospatial_mae.py b/vllm/model_executor/models/prithvi_geospatial_mae.py
index bfa90e42733db..d922329b3a499 100644
--- a/vllm/model_executor/models/prithvi_geospatial_mae.py
+++ b/vllm/model_executor/models/prithvi_geospatial_mae.py
@@ -25,7 +25,8 @@ from transformers import BatchFeature
 from vllm.config import VllmConfig
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.model_executor.models.interfaces import (IsAttentionFree,
-                                                   SupportsMultiModal)
+                                                   SupportsMultiModal,
+                                                   SupportsV0Only)
 from vllm.model_executor.models.utils import AutoWeightsLoader
 from vllm.model_executor.pooling_metadata import PoolingMetadata
 from vllm.multimodal import MULTIMODAL_REGISTRY
@@ -111,7 +112,8 @@ class PrithviGeoSpatialMAEMultiModalProcessor(BaseMultiModalProcessor):
     PrithviGeoSpatialMAEMultiModalProcessor,
     info=PrithviGeoSpatialMAEProcessingInfo,
     dummy_inputs=PrithviGeoSpatialMAEInputBuilder)
-class PrithviGeoSpatialMAE(nn.Module, IsAttentionFree, SupportsMultiModal):
+class PrithviGeoSpatialMAE(nn.Module, IsAttentionFree, SupportsMultiModal,
+                           SupportsV0Only):
     """ Prithvi Masked Autoencoder"""

     def _instantiate_model(self, config: dict) -> Optional[nn.Module]:
diff --git a/vllm/model_executor/models/qwen2_rm.py b/vllm/model_executor/models/qwen2_rm.py
index 21cc9e8ed1c6b..90f799e6734ed 100644
--- a/vllm/model_executor/models/qwen2_rm.py
+++ b/vllm/model_executor/models/qwen2_rm.py
@@ -17,7 +17,7 @@ from vllm.model_executor.layers.pooler import Pooler, PoolingType, SimplePooler
 from vllm.model_executor.pooling_metadata import PoolingMetadata
 from vllm.sequence import IntermediateTensors, PoolerOutput

-from .interfaces import SupportsLoRA, SupportsPP
+from .interfaces import SupportsLoRA, SupportsPP, SupportsV0Only
 from .qwen2 import Qwen2Model
 from .utils import AutoWeightsLoader, maybe_prefix

@@ -33,7 +33,8 @@ class ReLU(nn.Module):
         return self.activation(input)


-class Qwen2RewardBaseModel(nn.Module, SupportsLoRA, SupportsPP):
+class Qwen2RewardBaseModel(nn.Module, SupportsLoRA, SupportsPP,
+                           SupportsV0Only):
     packed_modules_mapping = {
         "qkv_proj": [
             "q_proj",
diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py
index 75e31d557dd10..028658b526446 100644
--- a/vllm/model_executor/models/registry.py
+++ b/vllm/model_executor/models/registry.py
@@ -22,7 +22,7 @@ from vllm.logger import init_logger

 from .interfaces import (has_inner_state, is_attention_free, is_hybrid,
                          supports_cross_encoding, supports_multimodal,
-                         supports_pp, supports_transcription)
+                         supports_pp, supports_transcription, supports_v0_only)
 from .interfaces_base import is_text_generation_model

 logger = init_logger(__name__)
@@ -228,6 +228,7 @@ class _ModelInfo:
     is_attention_free: bool
     is_hybrid: bool
     supports_transcription: bool
+    supports_v0_only: bool

     @staticmethod
     def from_model_cls(model: Type[nn.Module]) -> "_ModelInfo":
@@ -241,7 +242,9 @@ class _ModelInfo:
             has_inner_state=has_inner_state(model),
             is_attention_free=is_attention_free(model),
             is_hybrid=is_hybrid(model),
-            supports_transcription=supports_transcription(model))
+            supports_transcription=supports_transcription(model),
+            supports_v0_only=supports_v0_only(model),
+        )


 class _BaseRegisteredModel(ABC):
@@ -504,6 +507,13 @@ class _ModelRegistry:
         model_cls, _ = self.inspect_model_cls(architectures)
         return model_cls.supports_transcription

+    def is_v1_compatible(
+        self,
+        architectures: Union[str, List[str]],
+    ) -> bool:
+        model_cls, _ = self.inspect_model_cls(architectures)
+        return not model_cls.supports_v0_only
+

 ModelRegistry = _ModelRegistry({
     model_arch:
diff --git a/vllm/model_executor/models/roberta.py b/vllm/model_executor/models/roberta.py
index f86fa268072db..ba92eef12707c 100644
--- a/vllm/model_executor/models/roberta.py
+++ b/vllm/model_executor/models/roberta.py
@@ -19,7 +19,7 @@ from vllm.sequence import IntermediateTensors, PoolerOutput
 from vllm.transformers_utils.config import (
     get_cross_encoder_activation_function)

-from .interfaces import SupportsCrossEncoding
+from .interfaces import SupportsCrossEncoding, SupportsV0Only


 def roberta_task_weights_filter(
@@ -191,7 +191,8 @@ class RobertaEmbeddingModel(BertEmbeddingModel):
         assert len(loaded), "Unable to load RobertaEmbeddingModel"


-class RobertaForSequenceClassification(nn.Module, SupportsCrossEncoding):
+class RobertaForSequenceClassification(nn.Module, SupportsCrossEncoding,
+                                       SupportsV0Only):
     """A model that uses Roberta to provide embedding functionalities.

     This class encapsulates the BertModel and provides an interface for
diff --git a/vllm/model_executor/models/whisper.py b/vllm/model_executor/models/whisper.py
index 2da8c5c8b0e2e..656e5fc6dcf30 100644
--- a/vllm/model_executor/models/whisper.py
+++ b/vllm/model_executor/models/whisper.py
@@ -34,7 +34,8 @@ from vllm.multimodal.processing import (BaseProcessingInfo, PromptReplacement,
                                         PromptUpdate)
 from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs

-from .interfaces import SupportsMultiModal, SupportsTranscription
+from .interfaces import (SupportsMultiModal, SupportsTranscription,
+                         SupportsV0Only)
 from .utils import (AutoWeightsLoader, WeightsMapper, cast_overflow_tensors,
                     make_layers)

@@ -643,7 +644,7 @@ class WhisperMultiModalProcessor(
                                         info=WhisperProcessingInfo,
                                         dummy_inputs=WhisperDummyInputsBuilder)
 class WhisperForConditionalGeneration(nn.Module, SupportsTranscription,
-                                      SupportsMultiModal):
+                                      SupportsMultiModal, SupportsV0Only):
     packed_modules_mapping = {
         "self_attn.qkv_proj": [
             "self_attn.q_proj",
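
For reference, a minimal standalone sketch (not part of the patch) of how the opt-out marker behaves. The stand-in model classes below are hypothetical; they only illustrate the mechanism that supports_v0_only() and ModelRegistry.is_v1_compatible() rely on in the hunks above.

# sketch.py -- illustrative only; mirrors the SupportsV0Only mechanism above.
from typing import ClassVar, Literal, Protocol, runtime_checkable


@runtime_checkable
class SupportsV0Only(Protocol):
    """Marker protocol: models carrying it run only on the V0 engine."""

    supports_v0_only: ClassVar[Literal[True]] = True


class EncoderDecoderModel(SupportsV0Only):
    """Hypothetical model that opts out of V1 by inheriting the marker."""


class DecoderOnlyModel:
    """Hypothetical model that stays V1-compatible (no marker)."""


def is_v1_compatible(model_cls: type) -> bool:
    # Same idea as ModelRegistry.is_v1_compatible in the patch:
    # a model is V1-compatible unless its class advertises SupportsV0Only.
    return not isinstance(model_cls, SupportsV0Only)


print(is_v1_compatible(EncoderDecoderModel))  # False -> V0 engine only
print(is_v1_compatible(DecoderOnlyModel))     # True  -> usable with V1

In the patch itself, models opt out by listing SupportsV0Only among their base classes (e.g. BartForConditionalGeneration), and ModelConfig.is_v1_compatible consults the registry so the engine can decide which backend may load them.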