[V1] SupportsV0Only protocol for model definitions (#13959)

Signed-off-by: Roger Wang <ywang@roblox.com>
Roger Wang 2025-02-27 17:02:15 -08:00, committed by GitHub
parent 67fc426845
commit 6c85da3a18
19 changed files with 93 additions and 32 deletions

vllm/config.py

@@ -1039,6 +1039,11 @@ class ModelConfig:
     def runner_type(self) -> RunnerType:
         return _TASK_RUNNER[self.task]
 
+    @property
+    def is_v1_compatible(self) -> bool:
+        architectures = getattr(self.hf_config, "architectures", [])
+        return ModelRegistry.is_v1_compatible(architectures)
+
 
 class CacheConfig:
     """Configuration for the KV cache.

vllm/model_executor/models/__init__.py

@@ -1,8 +1,9 @@
 # SPDX-License-Identifier: Apache-2.0
 
 from .interfaces import (HasInnerState, SupportsLoRA, SupportsMultiModal,
-                         SupportsPP, has_inner_state, supports_lora,
-                         supports_multimodal, supports_pp)
+                         SupportsPP, SupportsV0Only, has_inner_state,
+                         supports_lora, supports_multimodal, supports_pp,
+                         supports_v0_only)
 from .interfaces_base import (VllmModelForPooling, VllmModelForTextGeneration,
                               is_pooling_model, is_text_generation_model)
 from .registry import ModelRegistry

@@ -21,4 +22,6 @@ __all__ = [
     "supports_multimodal",
     "SupportsPP",
     "supports_pp",
+    "SupportsV0Only",
+    "supports_v0_only",
 ]

vllm/model_executor/models/bamba.py

@@ -32,7 +32,8 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.sequence import IntermediateTensors
 from vllm.utils import LayerBlockType
 
-from .interfaces import HasInnerState, IsHybrid, SupportsLoRA, SupportsPP
+from .interfaces import (HasInnerState, IsHybrid, SupportsLoRA, SupportsPP,
+                         SupportsV0Only)
 from .utils import (is_pp_missing_parameter,
                     make_empty_intermediate_tensors_factory, make_layers,
                     maybe_prefix)

@@ -366,7 +367,7 @@ class BambaModel(nn.Module):
 
 
 class BambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
-                       IsHybrid):
+                       IsHybrid, SupportsV0Only):
     packed_modules_mapping = {
         "qkv_proj": [
             "q_proj",

vllm/model_executor/models/bart.py

@@ -43,6 +43,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.sequence import IntermediateTensors
 
+from .interfaces import SupportsV0Only
 from .utils import maybe_prefix
 
 logger = logging.get_logger(__name__)

@@ -776,7 +777,7 @@ class BartModel(nn.Module):
         return decoder_outputs
 
 
-class BartForConditionalGeneration(nn.Module):
+class BartForConditionalGeneration(nn.Module, SupportsV0Only):
     base_model_prefix = "model"
 
     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):

vllm/model_executor/models/bert.py

@@ -26,7 +26,7 @@ from vllm.sequence import IntermediateTensors, PoolerOutput
 from vllm.transformers_utils.config import (
     get_cross_encoder_activation_function)
 
-from .interfaces import SupportsCrossEncoding
+from .interfaces import SupportsCrossEncoding, SupportsV0Only
 from .utils import WeightsMapper, maybe_prefix

@@ -385,7 +385,7 @@ class BertModel(nn.Module):
         return loaded_params
 
 
-class BertEmbeddingModel(nn.Module):
+class BertEmbeddingModel(nn.Module, SupportsV0Only):
     """A model that uses Bert to provide embedding functionalities.
 
     This class encapsulates the BertModel and provides an interface for

vllm/model_executor/models/florence2.py

@@ -29,7 +29,7 @@ from vllm.multimodal.processing import (BaseProcessingInfo,
 from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
 from vllm.sequence import IntermediateTensors
 
-from .interfaces import SupportsMultiModal
+from .interfaces import SupportsMultiModal, SupportsV0Only
 from .utils import AutoWeightsLoader, flatten_bn, merge_multimodal_embeddings

@@ -651,7 +651,7 @@ class Florence2LanguageModel(nn.Module):
         return decoder_outputs
 
 
-class Florence2LanguageForConditionalGeneration(nn.Module):
+class Florence2LanguageForConditionalGeneration(nn.Module, SupportsV0Only):
 
     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         super().__init__()

vllm/model_executor/models/gritlm.py

@@ -19,6 +19,8 @@ from vllm.sequence import (IntermediateTensors, PoolerOutput,
                            PoolingSequenceGroupOutput)
 from vllm.transformers_utils.tokenizer import cached_tokenizer_from_config
 
+from .interfaces import SupportsV0Only
+
 logger = init_logger(__name__)

@@ -177,7 +179,7 @@ class GritLMPooler(nn.Module):
         return PoolerOutput(outputs=pooled_outputs)
 
 
-class GritLM(LlamaForCausalLM):
+class GritLM(LlamaForCausalLM, SupportsV0Only):
     """This class implements the embedding model for parasail-ai/GritLM-7B-vllm.
 
     The class inherits from LlamaForCausalLM and provides a custom pooling

vllm/model_executor/models/interfaces.py

@@ -498,3 +498,29 @@ def supports_transcription(
     return isinstance(model, SupportsTranscription)
 
 
+@runtime_checkable
+class SupportsV0Only(Protocol):
+    """Models with this interface are not compatible with V1 vLLM."""
+
+    supports_v0_only: ClassVar[Literal[True]] = True
+
+
+@overload
+def supports_v0_only(model: Type[object]) -> TypeIs[Type[SupportsV0Only]]:
+    ...
+
+
+@overload
+def supports_v0_only(model: object) -> TypeIs[SupportsV0Only]:
+    ...
+
+
+def supports_v0_only(
+    model: Union[Type[object], object],
+) -> Union[TypeIs[Type[SupportsV0Only]], TypeIs[SupportsV0Only]]:
+    if isinstance(model, type):
+        return isinstance(model, SupportsV0Only)
+
+    return isinstance(model, SupportsV0Only)
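
The protocol carries no methods: the ClassVar is the whole contract, so the runtime check reduces to an attribute probe that works on classes and instances alike, as the two overloads advertise. A quick sketch of the checker's behavior (the model class here is hypothetical):

import torch.nn as nn

from vllm.model_executor.models.interfaces import (SupportsV0Only,
                                                   supports_v0_only)

class LegacyOnlyModel(nn.Module, SupportsV0Only):
    """Hypothetical V0-only model."""

assert supports_v0_only(LegacyOnlyModel)    # accepts a class...
assert supports_v0_only(LegacyOnlyModel())  # ...or an instance
assert not supports_v0_only(nn.Module)      # untagged models report False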

vllm/model_executor/models/jamba.py

@@ -30,7 +30,8 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.sequence import IntermediateTensors, PoolerOutput
 from vllm.utils import LayerBlockType
 
-from .interfaces import HasInnerState, IsHybrid, SupportsLoRA, SupportsPP
+from .interfaces import (HasInnerState, IsHybrid, SupportsLoRA, SupportsPP,
+                         SupportsV0Only)
 from .utils import (is_pp_missing_parameter,
                     make_empty_intermediate_tensors_factory, make_layers,
                     maybe_prefix)

@@ -353,7 +354,7 @@ class JambaModel(nn.Module):
 
 
 class JambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
-                       IsHybrid):
+                       IsHybrid, SupportsV0Only):
     packed_modules_mapping = {
         "qkv_proj": [
             "q_proj",

vllm/model_executor/models/mamba.py

@@ -19,7 +19,8 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
     DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.model_executor.models.interfaces import (HasInnerState,
-                                                   IsAttentionFree, SupportsPP)
+                                                   IsAttentionFree, SupportsPP,
+                                                   SupportsV0Only)
 from vllm.model_executor.models.mamba_cache import (MambaCacheManager,
                                                     MambaCacheParams)
 from vllm.model_executor.sampling_metadata import SamplingMetadata

@@ -155,7 +156,8 @@ class MambaModel(nn.Module):
         return hidden_states
 
 
-class MambaForCausalLM(nn.Module, HasInnerState, IsAttentionFree, SupportsPP):
+class MambaForCausalLM(nn.Module, HasInnerState, IsAttentionFree, SupportsPP,
+                       SupportsV0Only):
 
     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         config = vllm_config.model_config.hf_config

vllm/model_executor/models/mamba2.py

@@ -22,7 +22,8 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
     DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.model_executor.models.interfaces import (HasInnerState,
-                                                   IsAttentionFree)
+                                                   IsAttentionFree,
+                                                   SupportsV0Only)
 from vllm.model_executor.models.mamba_cache import (MambaCacheManager,
                                                     MambaCacheParams)
 from vllm.model_executor.sampling_metadata import SamplingMetadata

@@ -174,7 +175,8 @@ class Mamba2Model(nn.Module):
         return hidden_states
 
 
-class Mamba2ForCausalLM(nn.Module, HasInnerState, IsAttentionFree):
+class Mamba2ForCausalLM(nn.Module, HasInnerState, IsAttentionFree,
+                        SupportsV0Only):
 
     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         config = vllm_config.model_config.hf_config

vllm/model_executor/models/minicpmv.py

@@ -63,7 +63,8 @@ from vllm.platforms import current_platform
 from vllm.sequence import IntermediateTensors
 
 from .idefics2_vision_model import Idefics2VisionTransformer
-from .interfaces import SupportsLoRA, SupportsMultiModal, SupportsPP
+from .interfaces import (SupportsLoRA, SupportsMultiModal, SupportsPP,
+                         SupportsV0Only)
 from .utils import AutoWeightsLoader, maybe_prefix
 
 CPU_DEVICE = torch.device("cpu")

@@ -804,7 +805,8 @@ class MiniCPMVMultiModalProcessor(BaseMultiModalProcessor[_I]):
         return result
 
 
-class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP):
+class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP,
+                        SupportsV0Only):
     """
     The abstract class of MiniCPMV can only be inherited, but cannot be
     instantiated.

vllm/model_executor/models/mllama.py

@@ -63,7 +63,7 @@ from vllm.multimodal.processing import (BaseProcessingInfo,
 from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
 
 from .clip import CLIPMLP
-from .interfaces import SupportsMultiModal
+from .interfaces import SupportsMultiModal, SupportsV0Only
 from .llama import LlamaDecoderLayer, LlamaMLP
 from .utils import maybe_prefix

@@ -1128,7 +1128,8 @@ class MllamaForCausalLM(nn.Module):
 @MULTIMODAL_REGISTRY.register_processor(MllamaMultiModalProcessor,
                                         info=MllamaProcessingInfo,
                                         dummy_inputs=MllamaDummyInputsBuilder)
-class MllamaForConditionalGeneration(nn.Module, SupportsMultiModal):
+class MllamaForConditionalGeneration(nn.Module, SupportsMultiModal,
+                                     SupportsV0Only):
     packed_modules_mapping = {
         "qkv_proj": ["q_proj", "k_proj", "v_proj"],
         "gate_up_proj": ["gate_proj", "up_proj"]

vllm/model_executor/models/paligemma.py

@@ -18,7 +18,7 @@ from vllm.multimodal.inputs import NestedTensors
 from vllm.sequence import IntermediateTensors
 from vllm.transformers_utils.tokenizer import cached_tokenizer_from_config
 
-from .interfaces import SupportsMultiModal, SupportsPP
+from .interfaces import SupportsMultiModal, SupportsPP, SupportsV0Only
 from .siglip import (SiglipVisionModel, dummy_image_for_siglip,
                      dummy_seq_data_for_siglip, get_max_siglip_image_tokens)
 from .utils import (AutoWeightsLoader, init_vllm_registered_model,

@@ -136,7 +136,7 @@ class PaliGemmaMultiModalProjector(nn.Module):
 @INPUT_REGISTRY.register_dummy_data(dummy_data_for_paligemma)
 @INPUT_REGISTRY.register_input_processor(input_processor_for_paligemma)
 class PaliGemmaForConditionalGeneration(nn.Module, SupportsMultiModal,
-                                        SupportsPP):
+                                        SupportsPP, SupportsV0Only):
     packed_modules_mapping = {
         "qkv_proj": [
             "q_proj",

vllm/model_executor/models/prithvi_geospatial_mae.py

@@ -25,7 +25,8 @@ from transformers import BatchFeature
 from vllm.config import VllmConfig
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.model_executor.models.interfaces import (IsAttentionFree,
-                                                   SupportsMultiModal)
+                                                   SupportsMultiModal,
+                                                   SupportsV0Only)
 from vllm.model_executor.models.utils import AutoWeightsLoader
 from vllm.model_executor.pooling_metadata import PoolingMetadata
 from vllm.multimodal import MULTIMODAL_REGISTRY

@@ -111,7 +112,8 @@ class PrithviGeoSpatialMAEMultiModalProcessor(BaseMultiModalProcessor):
     PrithviGeoSpatialMAEMultiModalProcessor,
     info=PrithviGeoSpatialMAEProcessingInfo,
     dummy_inputs=PrithviGeoSpatialMAEInputBuilder)
-class PrithviGeoSpatialMAE(nn.Module, IsAttentionFree, SupportsMultiModal):
+class PrithviGeoSpatialMAE(nn.Module, IsAttentionFree, SupportsMultiModal,
+                           SupportsV0Only):
     """ Prithvi Masked Autoencoder"""
 
     def _instantiate_model(self, config: dict) -> Optional[nn.Module]:

vllm/model_executor/models/qwen2_rm.py

@@ -17,7 +17,7 @@ from vllm.model_executor.layers.pooler import Pooler, PoolingType, SimplePooler
 from vllm.model_executor.pooling_metadata import PoolingMetadata
 from vllm.sequence import IntermediateTensors, PoolerOutput
 
-from .interfaces import SupportsLoRA, SupportsPP
+from .interfaces import SupportsLoRA, SupportsPP, SupportsV0Only
 from .qwen2 import Qwen2Model
 from .utils import AutoWeightsLoader, maybe_prefix

@@ -33,7 +33,8 @@ class ReLU(nn.Module):
         return self.activation(input)
 
 
-class Qwen2RewardBaseModel(nn.Module, SupportsLoRA, SupportsPP):
+class Qwen2RewardBaseModel(nn.Module, SupportsLoRA, SupportsPP,
+                           SupportsV0Only):
     packed_modules_mapping = {
         "qkv_proj": [
             "q_proj",

vllm/model_executor/models/registry.py

@@ -22,7 +22,7 @@ from vllm.logger import init_logger
 
 from .interfaces import (has_inner_state, is_attention_free, is_hybrid,
                          supports_cross_encoding, supports_multimodal,
-                         supports_pp, supports_transcription)
+                         supports_pp, supports_transcription, supports_v0_only)
 from .interfaces_base import is_text_generation_model
 
 logger = init_logger(__name__)

@@ -228,6 +228,7 @@ class _ModelInfo:
     is_attention_free: bool
     is_hybrid: bool
     supports_transcription: bool
+    supports_v0_only: bool
 
     @staticmethod
     def from_model_cls(model: Type[nn.Module]) -> "_ModelInfo":

@@ -241,7 +242,9 @@ class _ModelInfo:
             has_inner_state=has_inner_state(model),
             is_attention_free=is_attention_free(model),
             is_hybrid=is_hybrid(model),
-            supports_transcription=supports_transcription(model))
+            supports_transcription=supports_transcription(model),
+            supports_v0_only=supports_v0_only(model),
+        )
 
 
 class _BaseRegisteredModel(ABC):

@@ -504,6 +507,13 @@ class _ModelRegistry:
         model_cls, _ = self.inspect_model_cls(architectures)
         return model_cls.supports_transcription
 
+    def is_v1_compatible(
+        self,
+        architectures: Union[str, List[str]],
+    ) -> bool:
+        model_cls, _ = self.inspect_model_cls(architectures)
+        return not model_cls.supports_v0_only
+
 
 ModelRegistry = _ModelRegistry({
     model_arch:
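
With the flag recorded on _ModelInfo, callers can query compatibility by architecture name alone, without importing the model class. A short sketch of the lookup (architecture strings follow the hf_config.architectures convention; Bart is tagged SupportsV0Only earlier in this commit, so the registry should report it as V1-incompatible):

from vllm.model_executor.models.registry import ModelRegistry

assert not ModelRegistry.is_v1_compatible(["BartForConditionalGeneration"])
assert ModelRegistry.is_v1_compatible(["LlamaForCausalLM"])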

vllm/model_executor/models/roberta.py

@@ -19,7 +19,7 @@ from vllm.sequence import IntermediateTensors, PoolerOutput
 from vllm.transformers_utils.config import (
     get_cross_encoder_activation_function)
 
-from .interfaces import SupportsCrossEncoding
+from .interfaces import SupportsCrossEncoding, SupportsV0Only
 
 
 def roberta_task_weights_filter(

@@ -191,7 +191,8 @@ class RobertaEmbeddingModel(BertEmbeddingModel):
         assert len(loaded), "Unable to load RobertaEmbeddingModel"
 
 
-class RobertaForSequenceClassification(nn.Module, SupportsCrossEncoding):
+class RobertaForSequenceClassification(nn.Module, SupportsCrossEncoding,
+                                       SupportsV0Only):
     """A model that uses Roberta to provide embedding functionalities.
 
     This class encapsulates the BertModel and provides an interface for

vllm/model_executor/models/whisper.py

@@ -34,7 +34,8 @@ from vllm.multimodal.processing import (BaseProcessingInfo,
                                         PromptReplacement, PromptUpdate)
 from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
 
-from .interfaces import SupportsMultiModal, SupportsTranscription
+from .interfaces import (SupportsMultiModal, SupportsTranscription,
+                         SupportsV0Only)
 from .utils import (AutoWeightsLoader, WeightsMapper, cast_overflow_tensors,
                     make_layers)

@@ -643,7 +644,7 @@ class WhisperMultiModalProcessor(
                                         info=WhisperProcessingInfo,
                                         dummy_inputs=WhisperDummyInputsBuilder)
 class WhisperForConditionalGeneration(nn.Module, SupportsTranscription,
-                                      SupportsMultiModal):
+                                      SupportsMultiModal, SupportsV0Only):
     packed_modules_mapping = {
         "self_attn.qkv_proj": [
             "self_attn.q_proj",