mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-01-08 05:11:50 +08:00
[V1]SupportsV0Only protocol for model definitions (#13959)
Signed-off-by: Roger Wang <ywang@roblox.com>
This commit is contained in:
parent
67fc426845
commit
6c85da3a18
@ -1039,6 +1039,11 @@ class ModelConfig:
|
||||
def runner_type(self) -> RunnerType:
|
||||
return _TASK_RUNNER[self.task]
|
||||
|
||||
@property
|
||||
def is_v1_compatible(self) -> bool:
|
||||
architectures = getattr(self.hf_config, "architectures", [])
|
||||
return ModelRegistry.is_v1_compatible(architectures)
|
||||
|
||||
|
||||
class CacheConfig:
|
||||
"""Configuration for the KV cache.
|
||||
|
||||
@ -1,8 +1,9 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
from .interfaces import (HasInnerState, SupportsLoRA, SupportsMultiModal,
|
||||
SupportsPP, has_inner_state, supports_lora,
|
||||
supports_multimodal, supports_pp)
|
||||
SupportsPP, SupportsV0Only, has_inner_state,
|
||||
supports_lora, supports_multimodal, supports_pp,
|
||||
supports_v0_only)
|
||||
from .interfaces_base import (VllmModelForPooling, VllmModelForTextGeneration,
|
||||
is_pooling_model, is_text_generation_model)
|
||||
from .registry import ModelRegistry
|
||||
@ -21,4 +22,6 @@ __all__ = [
|
||||
"supports_multimodal",
|
||||
"SupportsPP",
|
||||
"supports_pp",
|
||||
"SupportsV0Only",
|
||||
"supports_v0_only",
|
||||
]
|
||||
|
||||
@ -32,7 +32,8 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata
|
||||
from vllm.sequence import IntermediateTensors
|
||||
from vllm.utils import LayerBlockType
|
||||
|
||||
from .interfaces import HasInnerState, IsHybrid, SupportsLoRA, SupportsPP
|
||||
from .interfaces import (HasInnerState, IsHybrid, SupportsLoRA, SupportsPP,
|
||||
SupportsV0Only)
|
||||
from .utils import (is_pp_missing_parameter,
|
||||
make_empty_intermediate_tensors_factory, make_layers,
|
||||
maybe_prefix)
|
||||
@ -366,7 +367,7 @@ class BambaModel(nn.Module):
|
||||
|
||||
|
||||
class BambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
|
||||
IsHybrid):
|
||||
IsHybrid, SupportsV0Only):
|
||||
packed_modules_mapping = {
|
||||
"qkv_proj": [
|
||||
"q_proj",
|
||||
|
||||
@ -43,6 +43,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
||||
from vllm.sequence import IntermediateTensors
|
||||
|
||||
from .interfaces import SupportsV0Only
|
||||
from .utils import maybe_prefix
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
@ -776,7 +777,7 @@ class BartModel(nn.Module):
|
||||
return decoder_outputs
|
||||
|
||||
|
||||
class BartForConditionalGeneration(nn.Module):
|
||||
class BartForConditionalGeneration(nn.Module, SupportsV0Only):
|
||||
base_model_prefix = "model"
|
||||
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
|
||||
@ -26,7 +26,7 @@ from vllm.sequence import IntermediateTensors, PoolerOutput
|
||||
from vllm.transformers_utils.config import (
|
||||
get_cross_encoder_activation_function)
|
||||
|
||||
from .interfaces import SupportsCrossEncoding
|
||||
from .interfaces import SupportsCrossEncoding, SupportsV0Only
|
||||
from .utils import WeightsMapper, maybe_prefix
|
||||
|
||||
|
||||
@ -385,7 +385,7 @@ class BertModel(nn.Module):
|
||||
return loaded_params
|
||||
|
||||
|
||||
class BertEmbeddingModel(nn.Module):
|
||||
class BertEmbeddingModel(nn.Module, SupportsV0Only):
|
||||
"""A model that uses Bert to provide embedding functionalities.
|
||||
|
||||
This class encapsulates the BertModel and provides an interface for
|
||||
|
||||
@ -29,7 +29,7 @@ from vllm.multimodal.processing import (BaseProcessingInfo,
|
||||
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
|
||||
from vllm.sequence import IntermediateTensors
|
||||
|
||||
from .interfaces import SupportsMultiModal
|
||||
from .interfaces import SupportsMultiModal, SupportsV0Only
|
||||
from .utils import AutoWeightsLoader, flatten_bn, merge_multimodal_embeddings
|
||||
|
||||
|
||||
@ -651,7 +651,7 @@ class Florence2LanguageModel(nn.Module):
|
||||
return decoder_outputs
|
||||
|
||||
|
||||
class Florence2LanguageForConditionalGeneration(nn.Module):
|
||||
class Florence2LanguageForConditionalGeneration(nn.Module, SupportsV0Only):
|
||||
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
super().__init__()
|
||||
|
||||
@ -19,6 +19,8 @@ from vllm.sequence import (IntermediateTensors, PoolerOutput,
|
||||
PoolingSequenceGroupOutput)
|
||||
from vllm.transformers_utils.tokenizer import cached_tokenizer_from_config
|
||||
|
||||
from .interfaces import SupportsV0Only
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
@ -177,7 +179,7 @@ class GritLMPooler(nn.Module):
|
||||
return PoolerOutput(outputs=pooled_outputs)
|
||||
|
||||
|
||||
class GritLM(LlamaForCausalLM):
|
||||
class GritLM(LlamaForCausalLM, SupportsV0Only):
|
||||
"""This class implements the embedding model for parasail-ai/GritLM-7B-vllm.
|
||||
|
||||
The class inherits from LlamaForCausalLM and provides a custom pooling
|
||||
|
||||
@ -498,3 +498,29 @@ def supports_transcription(
|
||||
return isinstance(model, SupportsTranscription)
|
||||
|
||||
return isinstance(model, SupportsTranscription)
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class SupportsV0Only(Protocol):
|
||||
"""Models with this interface are not compatible with V1 vLLM."""
|
||||
|
||||
supports_v0_only: ClassVar[Literal[True]] = True
|
||||
|
||||
|
||||
@overload
|
||||
def supports_v0_only(model: Type[object]) -> TypeIs[Type[SupportsV0Only]]:
|
||||
...
|
||||
|
||||
|
||||
@overload
|
||||
def supports_v0_only(model: object) -> TypeIs[SupportsV0Only]:
|
||||
...
|
||||
|
||||
|
||||
def supports_v0_only(
|
||||
model: Union[Type[object], object],
|
||||
) -> Union[TypeIs[Type[SupportsV0Only]], TypeIs[SupportsV0Only]]:
|
||||
if isinstance(model, type):
|
||||
return isinstance(model, SupportsV0Only)
|
||||
|
||||
return isinstance(model, SupportsV0Only)
|
||||
|
||||
@ -30,7 +30,8 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata
|
||||
from vllm.sequence import IntermediateTensors, PoolerOutput
|
||||
from vllm.utils import LayerBlockType
|
||||
|
||||
from .interfaces import HasInnerState, IsHybrid, SupportsLoRA, SupportsPP
|
||||
from .interfaces import (HasInnerState, IsHybrid, SupportsLoRA, SupportsPP,
|
||||
SupportsV0Only)
|
||||
from .utils import (is_pp_missing_parameter,
|
||||
make_empty_intermediate_tensors_factory, make_layers,
|
||||
maybe_prefix)
|
||||
@ -353,7 +354,7 @@ class JambaModel(nn.Module):
|
||||
|
||||
|
||||
class JambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
|
||||
IsHybrid):
|
||||
IsHybrid, SupportsV0Only):
|
||||
packed_modules_mapping = {
|
||||
"qkv_proj": [
|
||||
"q_proj",
|
||||
|
||||
@ -19,7 +19,8 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||
DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
|
||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||
from vllm.model_executor.models.interfaces import (HasInnerState,
|
||||
IsAttentionFree, SupportsPP)
|
||||
IsAttentionFree, SupportsPP,
|
||||
SupportsV0Only)
|
||||
from vllm.model_executor.models.mamba_cache import (MambaCacheManager,
|
||||
MambaCacheParams)
|
||||
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
||||
@ -155,7 +156,8 @@ class MambaModel(nn.Module):
|
||||
return hidden_states
|
||||
|
||||
|
||||
class MambaForCausalLM(nn.Module, HasInnerState, IsAttentionFree, SupportsPP):
|
||||
class MambaForCausalLM(nn.Module, HasInnerState, IsAttentionFree, SupportsPP,
|
||||
SupportsV0Only):
|
||||
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
config = vllm_config.model_config.hf_config
|
||||
|
||||
@ -22,7 +22,8 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||
DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
|
||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||
from vllm.model_executor.models.interfaces import (HasInnerState,
|
||||
IsAttentionFree)
|
||||
IsAttentionFree,
|
||||
SupportsV0Only)
|
||||
from vllm.model_executor.models.mamba_cache import (MambaCacheManager,
|
||||
MambaCacheParams)
|
||||
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
||||
@ -174,7 +175,8 @@ class Mamba2Model(nn.Module):
|
||||
return hidden_states
|
||||
|
||||
|
||||
class Mamba2ForCausalLM(nn.Module, HasInnerState, IsAttentionFree):
|
||||
class Mamba2ForCausalLM(nn.Module, HasInnerState, IsAttentionFree,
|
||||
SupportsV0Only):
|
||||
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
config = vllm_config.model_config.hf_config
|
||||
|
||||
@ -63,7 +63,8 @@ from vllm.platforms import current_platform
|
||||
from vllm.sequence import IntermediateTensors
|
||||
|
||||
from .idefics2_vision_model import Idefics2VisionTransformer
|
||||
from .interfaces import SupportsLoRA, SupportsMultiModal, SupportsPP
|
||||
from .interfaces import (SupportsLoRA, SupportsMultiModal, SupportsPP,
|
||||
SupportsV0Only)
|
||||
from .utils import AutoWeightsLoader, maybe_prefix
|
||||
|
||||
CPU_DEVICE = torch.device("cpu")
|
||||
@ -804,7 +805,8 @@ class MiniCPMVMultiModalProcessor(BaseMultiModalProcessor[_I]):
|
||||
return result
|
||||
|
||||
|
||||
class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP):
|
||||
class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP,
|
||||
SupportsV0Only):
|
||||
"""
|
||||
The abstract class of MiniCPMV can only be inherited, but cannot be
|
||||
instantiated.
|
||||
|
||||
@ -63,7 +63,7 @@ from vllm.multimodal.processing import (BaseProcessingInfo,
|
||||
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
|
||||
|
||||
from .clip import CLIPMLP
|
||||
from .interfaces import SupportsMultiModal
|
||||
from .interfaces import SupportsMultiModal, SupportsV0Only
|
||||
from .llama import LlamaDecoderLayer, LlamaMLP
|
||||
from .utils import maybe_prefix
|
||||
|
||||
@ -1128,7 +1128,8 @@ class MllamaForCausalLM(nn.Module):
|
||||
@MULTIMODAL_REGISTRY.register_processor(MllamaMultiModalProcessor,
|
||||
info=MllamaProcessingInfo,
|
||||
dummy_inputs=MllamaDummyInputsBuilder)
|
||||
class MllamaForConditionalGeneration(nn.Module, SupportsMultiModal):
|
||||
class MllamaForConditionalGeneration(nn.Module, SupportsMultiModal,
|
||||
SupportsV0Only):
|
||||
packed_modules_mapping = {
|
||||
"qkv_proj": ["q_proj", "k_proj", "v_proj"],
|
||||
"gate_up_proj": ["gate_proj", "up_proj"]
|
||||
|
||||
@ -18,7 +18,7 @@ from vllm.multimodal.inputs import NestedTensors
|
||||
from vllm.sequence import IntermediateTensors
|
||||
from vllm.transformers_utils.tokenizer import cached_tokenizer_from_config
|
||||
|
||||
from .interfaces import SupportsMultiModal, SupportsPP
|
||||
from .interfaces import SupportsMultiModal, SupportsPP, SupportsV0Only
|
||||
from .siglip import (SiglipVisionModel, dummy_image_for_siglip,
|
||||
dummy_seq_data_for_siglip, get_max_siglip_image_tokens)
|
||||
from .utils import (AutoWeightsLoader, init_vllm_registered_model,
|
||||
@ -136,7 +136,7 @@ class PaliGemmaMultiModalProjector(nn.Module):
|
||||
@INPUT_REGISTRY.register_dummy_data(dummy_data_for_paligemma)
|
||||
@INPUT_REGISTRY.register_input_processor(input_processor_for_paligemma)
|
||||
class PaliGemmaForConditionalGeneration(nn.Module, SupportsMultiModal,
|
||||
SupportsPP):
|
||||
SupportsPP, SupportsV0Only):
|
||||
packed_modules_mapping = {
|
||||
"qkv_proj": [
|
||||
"q_proj",
|
||||
|
||||
@ -25,7 +25,8 @@ from transformers import BatchFeature
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||
from vllm.model_executor.models.interfaces import (IsAttentionFree,
|
||||
SupportsMultiModal)
|
||||
SupportsMultiModal,
|
||||
SupportsV0Only)
|
||||
from vllm.model_executor.models.utils import AutoWeightsLoader
|
||||
from vllm.model_executor.pooling_metadata import PoolingMetadata
|
||||
from vllm.multimodal import MULTIMODAL_REGISTRY
|
||||
@ -111,7 +112,8 @@ class PrithviGeoSpatialMAEMultiModalProcessor(BaseMultiModalProcessor):
|
||||
PrithviGeoSpatialMAEMultiModalProcessor,
|
||||
info=PrithviGeoSpatialMAEProcessingInfo,
|
||||
dummy_inputs=PrithviGeoSpatialMAEInputBuilder)
|
||||
class PrithviGeoSpatialMAE(nn.Module, IsAttentionFree, SupportsMultiModal):
|
||||
class PrithviGeoSpatialMAE(nn.Module, IsAttentionFree, SupportsMultiModal,
|
||||
SupportsV0Only):
|
||||
""" Prithvi Masked Autoencoder"""
|
||||
|
||||
def _instantiate_model(self, config: dict) -> Optional[nn.Module]:
|
||||
|
||||
@ -17,7 +17,7 @@ from vllm.model_executor.layers.pooler import Pooler, PoolingType, SimplePooler
|
||||
from vllm.model_executor.pooling_metadata import PoolingMetadata
|
||||
from vllm.sequence import IntermediateTensors, PoolerOutput
|
||||
|
||||
from .interfaces import SupportsLoRA, SupportsPP
|
||||
from .interfaces import SupportsLoRA, SupportsPP, SupportsV0Only
|
||||
from .qwen2 import Qwen2Model
|
||||
from .utils import AutoWeightsLoader, maybe_prefix
|
||||
|
||||
@ -33,7 +33,8 @@ class ReLU(nn.Module):
|
||||
return self.activation(input)
|
||||
|
||||
|
||||
class Qwen2RewardBaseModel(nn.Module, SupportsLoRA, SupportsPP):
|
||||
class Qwen2RewardBaseModel(nn.Module, SupportsLoRA, SupportsPP,
|
||||
SupportsV0Only):
|
||||
packed_modules_mapping = {
|
||||
"qkv_proj": [
|
||||
"q_proj",
|
||||
|
||||
@ -22,7 +22,7 @@ from vllm.logger import init_logger
|
||||
|
||||
from .interfaces import (has_inner_state, is_attention_free, is_hybrid,
|
||||
supports_cross_encoding, supports_multimodal,
|
||||
supports_pp, supports_transcription)
|
||||
supports_pp, supports_transcription, supports_v0_only)
|
||||
from .interfaces_base import is_text_generation_model
|
||||
|
||||
logger = init_logger(__name__)
|
||||
@ -228,6 +228,7 @@ class _ModelInfo:
|
||||
is_attention_free: bool
|
||||
is_hybrid: bool
|
||||
supports_transcription: bool
|
||||
supports_v0_only: bool
|
||||
|
||||
@staticmethod
|
||||
def from_model_cls(model: Type[nn.Module]) -> "_ModelInfo":
|
||||
@ -241,7 +242,9 @@ class _ModelInfo:
|
||||
has_inner_state=has_inner_state(model),
|
||||
is_attention_free=is_attention_free(model),
|
||||
is_hybrid=is_hybrid(model),
|
||||
supports_transcription=supports_transcription(model))
|
||||
supports_transcription=supports_transcription(model),
|
||||
supports_v0_only=supports_v0_only(model),
|
||||
)
|
||||
|
||||
|
||||
class _BaseRegisteredModel(ABC):
|
||||
@ -504,6 +507,13 @@ class _ModelRegistry:
|
||||
model_cls, _ = self.inspect_model_cls(architectures)
|
||||
return model_cls.supports_transcription
|
||||
|
||||
def is_v1_compatible(
|
||||
self,
|
||||
architectures: Union[str, List[str]],
|
||||
) -> bool:
|
||||
model_cls, _ = self.inspect_model_cls(architectures)
|
||||
return not model_cls.supports_v0_only
|
||||
|
||||
|
||||
ModelRegistry = _ModelRegistry({
|
||||
model_arch:
|
||||
|
||||
@ -19,7 +19,7 @@ from vllm.sequence import IntermediateTensors, PoolerOutput
|
||||
from vllm.transformers_utils.config import (
|
||||
get_cross_encoder_activation_function)
|
||||
|
||||
from .interfaces import SupportsCrossEncoding
|
||||
from .interfaces import SupportsCrossEncoding, SupportsV0Only
|
||||
|
||||
|
||||
def roberta_task_weights_filter(
|
||||
@ -191,7 +191,8 @@ class RobertaEmbeddingModel(BertEmbeddingModel):
|
||||
assert len(loaded), "Unable to load RobertaEmbeddingModel"
|
||||
|
||||
|
||||
class RobertaForSequenceClassification(nn.Module, SupportsCrossEncoding):
|
||||
class RobertaForSequenceClassification(nn.Module, SupportsCrossEncoding,
|
||||
SupportsV0Only):
|
||||
"""A model that uses Roberta to provide embedding functionalities.
|
||||
|
||||
This class encapsulates the BertModel and provides an interface for
|
||||
|
||||
@ -34,7 +34,8 @@ from vllm.multimodal.processing import (BaseProcessingInfo,
|
||||
PromptReplacement, PromptUpdate)
|
||||
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
|
||||
|
||||
from .interfaces import SupportsMultiModal, SupportsTranscription
|
||||
from .interfaces import (SupportsMultiModal, SupportsTranscription,
|
||||
SupportsV0Only)
|
||||
from .utils import (AutoWeightsLoader, WeightsMapper, cast_overflow_tensors,
|
||||
make_layers)
|
||||
|
||||
@ -643,7 +644,7 @@ class WhisperMultiModalProcessor(
|
||||
info=WhisperProcessingInfo,
|
||||
dummy_inputs=WhisperDummyInputsBuilder)
|
||||
class WhisperForConditionalGeneration(nn.Module, SupportsTranscription,
|
||||
SupportsMultiModal):
|
||||
SupportsMultiModal, SupportsV0Only):
|
||||
packed_modules_mapping = {
|
||||
"self_attn.qkv_proj": [
|
||||
"self_attn.q_proj",
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user