diff --git a/vllm/config/model.py b/vllm/config/model.py
index b32d820edd7b5..adb0dd9ac9f5c 100644
--- a/vllm/config/model.py
+++ b/vllm/config/model.py
@@ -1656,6 +1656,10 @@ class ModelConfig:
     def has_inner_state(self):
         return self._model_info.has_inner_state
 
+    @property
+    def supports_mamba_prefix_caching(self) -> bool:
+        return self._model_info.supports_mamba_prefix_caching
+
     @property
     def use_mla(self) -> bool:
         return self.is_deepseek_mla and not envs.VLLM_MLA_DISABLE
diff --git a/vllm/model_executor/models/bamba.py b/vllm/model_executor/models/bamba.py
index 1a06f0659235e..151fb3b6acc46 100644
--- a/vllm/model_executor/models/bamba.py
+++ b/vllm/model_executor/models/bamba.py
@@ -37,7 +37,14 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.sequence import IntermediateTensors
 
-from .interfaces import HasInnerState, IsHybrid, SupportsLoRA, SupportsPP, SupportsQuant
+from .interfaces import (
+    HasInnerState,
+    IsHybrid,
+    SupportsLoRA,
+    SupportsMambaPrefixCaching,
+    SupportsPP,
+    SupportsQuant,
+)
 from .utils import (
     AutoWeightsLoader,
     is_pp_missing_parameter,
@@ -394,7 +401,13 @@ class BambaModel(nn.Module):
 
 
 class BambaForCausalLM(
-    nn.Module, HasInnerState, SupportsLoRA, SupportsPP, IsHybrid, SupportsQuant
+    nn.Module,
+    HasInnerState,
+    SupportsLoRA,
+    SupportsPP,
+    IsHybrid,
+    SupportsQuant,
+    SupportsMambaPrefixCaching,
 ):
     packed_modules_mapping = {
         "qkv_proj": [
diff --git a/vllm/model_executor/models/config.py b/vllm/model_executor/models/config.py
index d2f9f1b0b5c06..493b74bddda7a 100644
--- a/vllm/model_executor/models/config.py
+++ b/vllm/model_executor/models/config.py
@@ -295,17 +295,8 @@ class MambaModelConfig(VerifyAndUpdateConfig):
         # override by prefix caching logic later)
         cache_config.mamba_block_size = model_config.max_model_len
 
-        # TODO(@tdoublep) find a better way to do this than whitelist
-        MAMBA2_MODELS = [
-            "BambaForCausalLM",
-            "FalconH1ForCausalLM",
-            "GraniteMoeHybridForCausalLM",
-            "Mamba2ForCausalLM",
-            "NemotronHForCausalLM",
-            "Zamba2ForCausalLM",
-        ]
         if cache_config.enable_prefix_caching:
-            if model_config.architecture in MAMBA2_MODELS:
+            if model_config.supports_mamba_prefix_caching:
                 logger.info(
                     "Warning: Prefix caching is currently enabled. "
                     "Its support for Mamba2 layers is experimental. "
diff --git a/vllm/model_executor/models/falcon_h1.py b/vllm/model_executor/models/falcon_h1.py
index 4e0b6b52fc647..8bf700b474a41 100644
--- a/vllm/model_executor/models/falcon_h1.py
+++ b/vllm/model_executor/models/falcon_h1.py
@@ -37,7 +37,13 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.sequence import IntermediateTensors
 
-from .interfaces import HasInnerState, IsHybrid, SupportsLoRA, SupportsPP
+from .interfaces import (
+    HasInnerState,
+    IsHybrid,
+    SupportsLoRA,
+    SupportsMambaPrefixCaching,
+    SupportsPP,
+)
 from .utils import (
     PPMissingLayer,
     is_pp_missing_parameter,
@@ -495,7 +501,14 @@ class FalconH1Model(nn.Module):
         return hidden_states
 
 
-class FalconH1ForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP, IsHybrid):
+class FalconH1ForCausalLM(
+    nn.Module,
+    HasInnerState,
+    SupportsLoRA,
+    SupportsPP,
+    IsHybrid,
+    SupportsMambaPrefixCaching,
+):
     packed_modules_mapping = {
         "qkv_proj": ["q_proj", "k_proj", "v_proj"],
         "gate_up_proj": ["gate_proj", "up_proj"],
diff --git a/vllm/model_executor/models/granitemoehybrid.py b/vllm/model_executor/models/granitemoehybrid.py
index 1bb7f4e9b8023..bac64eec8c558 100644
--- a/vllm/model_executor/models/granitemoehybrid.py
+++ b/vllm/model_executor/models/granitemoehybrid.py
@@ -34,7 +34,14 @@ from vllm.sequence import IntermediateTensors
 
 from .granitemoe import GraniteMoeMoE
 from .granitemoeshared import GraniteMoeSharedMLP
-from .interfaces import HasInnerState, IsHybrid, SupportsLoRA, SupportsPP, SupportsQuant
+from .interfaces import (
+    HasInnerState,
+    IsHybrid,
+    SupportsLoRA,
+    SupportsMambaPrefixCaching,
+    SupportsPP,
+    SupportsQuant,
+)
 from .utils import (
     AutoWeightsLoader,
     is_pp_missing_parameter,
@@ -584,7 +591,13 @@ class GraniteMoeHybridModel(nn.Module):
 
 
 class GraniteMoeHybridForCausalLM(
-    nn.Module, HasInnerState, SupportsLoRA, SupportsPP, IsHybrid, SupportsQuant
+    nn.Module,
+    HasInnerState,
+    SupportsLoRA,
+    SupportsPP,
+    IsHybrid,
+    SupportsQuant,
+    SupportsMambaPrefixCaching,
 ):
     packed_modules_mapping = {
         "qkv_proj": [
diff --git a/vllm/model_executor/models/interfaces.py b/vllm/model_executor/models/interfaces.py
index 1bc5f5ae5419f..e133206c27a8b 100644
--- a/vllm/model_executor/models/interfaces.py
+++ b/vllm/model_executor/models/interfaces.py
@@ -697,6 +697,34 @@ def has_noops(
     return getattr(model, "has_noops", False)
 
 
+@runtime_checkable
+class SupportsMambaPrefixCaching(Protocol):
+    """The interface for models whose mamba layers support prefix caching.
+
+    This is currently experimental.
+    """
+
+    supports_mamba_prefix_caching: ClassVar[Literal[True]] = True
+
+
+@overload
+def supports_mamba_prefix_caching(
+    model: object,
+) -> TypeIs[SupportsMambaPrefixCaching]: ...
+
+
+@overload
+def supports_mamba_prefix_caching(
+    model: type[object],
+) -> TypeIs[type[SupportsMambaPrefixCaching]]: ...
+
+
+def supports_mamba_prefix_caching(
+    model: type[object] | object,
+) -> TypeIs[type[SupportsMambaPrefixCaching]] | TypeIs[SupportsMambaPrefixCaching]:
+    return getattr(model, "supports_mamba_prefix_caching", False)
+
+
 @runtime_checkable
 class SupportsCrossEncoding(Protocol):
     """The interface required for all models that support cross encoding."""
diff --git a/vllm/model_executor/models/mamba2.py b/vllm/model_executor/models/mamba2.py
index 5eb21b966e187..8ba8af66635b3 100644
--- a/vllm/model_executor/models/mamba2.py
+++ b/vllm/model_executor/models/mamba2.py
@@ -25,7 +25,11 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
     VocabParallelEmbedding,
 )
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
-from vllm.model_executor.models.interfaces import HasInnerState, IsAttentionFree
+from vllm.model_executor.models.interfaces import (
+    HasInnerState,
+    IsAttentionFree,
+    SupportsMambaPrefixCaching,
+)
 from vllm.sequence import IntermediateTensors
 
 from .utils import (
@@ -189,7 +193,9 @@ class Mamba2Model(nn.Module):
         return loaded_params
 
 
-class Mamba2ForCausalLM(nn.Module, HasInnerState, IsAttentionFree):
+class Mamba2ForCausalLM(
+    nn.Module, HasInnerState, IsAttentionFree, SupportsMambaPrefixCaching
+):
     @classmethod
     def get_mamba_state_dtype_from_config(
         cls,
diff --git a/vllm/model_executor/models/nemotron_h.py b/vllm/model_executor/models/nemotron_h.py
index f31579e5cfa82..457d3910d0e57 100644
--- a/vllm/model_executor/models/nemotron_h.py
+++ b/vllm/model_executor/models/nemotron_h.py
@@ -62,6 +62,7 @@ from vllm.model_executor.models.interfaces import (
     IsHybrid,
     MixtureOfExperts,
     SupportsLoRA,
+    SupportsMambaPrefixCaching,
     SupportsPP,
     SupportsQuant,
 )
@@ -695,6 +696,7 @@ class NemotronHForCausalLM(
     IsHybrid,
     SupportsQuant,
     MixtureOfExperts,
+    SupportsMambaPrefixCaching,
 ):
     hf_to_vllm_mapper = WeightsMapper(
         orig_to_new_prefix={"backbone": "model"},
diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py
index e8212ef6d72d8..0027954ac2771 100644
--- a/vllm/model_executor/models/registry.py
+++ b/vllm/model_executor/models/registry.py
@@ -39,6 +39,7 @@ from .interfaces import (
     is_attention_free,
     is_hybrid,
     supports_cross_encoding,
+    supports_mamba_prefix_caching,
     supports_multimodal,
     supports_multimodal_encoder_tp_data,
     supports_multimodal_raw_input_only,
@@ -496,6 +497,7 @@ class _ModelInfo:
     is_attention_free: bool
     is_hybrid: bool
     has_noops: bool
+    supports_mamba_prefix_caching: bool
     supports_transcription: bool
     supports_transcription_only: bool
 
@@ -518,6 +520,7 @@ class _ModelInfo:
             has_inner_state=has_inner_state(model),
             is_attention_free=is_attention_free(model),
             is_hybrid=is_hybrid(model),
+            supports_mamba_prefix_caching=supports_mamba_prefix_caching(model),
             supports_transcription=supports_transcription(model),
             supports_transcription_only=(
                 supports_transcription(model) and model.supports_transcription_only
diff --git a/vllm/model_executor/models/zamba2.py b/vllm/model_executor/models/zamba2.py
index 2610aa253b575..a6cfcf509776f 100644
--- a/vllm/model_executor/models/zamba2.py
+++ b/vllm/model_executor/models/zamba2.py
@@ -45,7 +45,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.sequence import IntermediateTensors
 
-from .interfaces import HasInnerState, IsHybrid
+from .interfaces import HasInnerState, IsHybrid, SupportsMambaPrefixCaching
 from .utils import AutoWeightsLoader, WeightsMapper, maybe_prefix
 
 
@@ -824,7 +824,7 @@ class Zamba2Model(nn.Module):
         return loaded_params
 
 
-class Zamba2ForCausalLM(nn.Module, HasInnerState, IsHybrid):
+class Zamba2ForCausalLM(nn.Module, HasInnerState, IsHybrid, SupportsMambaPrefixCaching):
    """Zamba2 model with causal language modeling head.

    This class wraps the core Zamba2 model and adds:
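
For reference, a minimal usage sketch (not part of the patch above): a model class opts in by inheriting the new SupportsMambaPrefixCaching mixin, and the supports_mamba_prefix_caching() helper added in interfaces.py detects it on the class object, which is how the registry populates _ModelInfo. "MyHybridForCausalLM" is a hypothetical class used purely for illustration.

    import torch.nn as nn

    from vllm.model_executor.models.interfaces import (
        SupportsMambaPrefixCaching,
        supports_mamba_prefix_caching,
    )


    class MyHybridForCausalLM(nn.Module, SupportsMambaPrefixCaching):
        """Hypothetical hybrid model, shown only to illustrate the interface."""


    # The ClassVar default declared on the protocol is inherited, so the check
    # succeeds both on the class (as the registry does it) and on instances.
    assert supports_mamba_prefix_caching(MyHybridForCausalLM)
    assert supports_mamba_prefix_caching(MyHybridForCausalLM())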