[Hybrid] Added supports_mamba_prefix_caching Protocol (#27339)

Signed-off-by: asafg <39553475+Josephasafg@users.noreply.github.com>
Asaf Joseph Gardin 2025-10-27 15:05:20 +02:00 committed by GitHub
parent f4e8154076
commit 9273754222
10 changed files with 93 additions and 20 deletions

View File

@@ -1656,6 +1656,10 @@ class ModelConfig:
     def has_inner_state(self):
         return self._model_info.has_inner_state

+    @property
+    def supports_mamba_prefix_caching(self) -> bool:
+        return self._model_info.supports_mamba_prefix_caching
+
     @property
     def use_mla(self) -> bool:
         return self.is_deepseek_mla and not envs.VLLM_MLA_DISABLE
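This property is the public end of the plumbing this commit adds: the flag originates as a ClassVar on the model class, is snapshotted into _ModelInfo by the registry (see the registry diff below), and is read here. A minimal hedged sketch of a consumer; check_mamba_prefix_caching is a hypothetical helper, and both config objects are assumed to be already constructed:

def check_mamba_prefix_caching(model_config, cache_config) -> None:
    # Hypothetical consumer; mirrors the real check in MambaModelConfig
    # further down in this commit.
    if (
        cache_config.enable_prefix_caching
        and model_config.supports_mamba_prefix_caching
    ):
        print("Prefix caching enabled; Mamba2 support is experimental.")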

View File

@@ -37,7 +37,14 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.sequence import IntermediateTensors

-from .interfaces import HasInnerState, IsHybrid, SupportsLoRA, SupportsPP, SupportsQuant
+from .interfaces import (
+    HasInnerState,
+    IsHybrid,
+    SupportsLoRA,
+    SupportsMambaPrefixCaching,
+    SupportsPP,
+    SupportsQuant,
+)
 from .utils import (
     AutoWeightsLoader,
     is_pp_missing_parameter,

@@ -394,7 +401,13 @@ class BambaModel(nn.Module):


 class BambaForCausalLM(
-    nn.Module, HasInnerState, SupportsLoRA, SupportsPP, IsHybrid, SupportsQuant
+    nn.Module,
+    HasInnerState,
+    SupportsLoRA,
+    SupportsPP,
+    IsHybrid,
+    SupportsQuant,
+    SupportsMambaPrefixCaching,
 ):
     packed_modules_mapping = {
         "qkv_proj": [

View File

@@ -295,17 +295,8 @@ class MambaModelConfig(VerifyAndUpdateConfig):
         # override by prefix caching logic later)
         cache_config.mamba_block_size = model_config.max_model_len

-        # TODO(@tdoublep) find a better way to do this than whitelist
-        MAMBA2_MODELS = [
-            "BambaForCausalLM",
-            "FalconH1ForCausalLM",
-            "GraniteMoeHybridForCausalLM",
-            "Mamba2ForCausalLM",
-            "NemotronHForCausalLM",
-            "Zamba2ForCausalLM",
-        ]
         if cache_config.enable_prefix_caching:
-            if model_config.architecture in MAMBA2_MODELS:
+            if model_config.supports_mamba_prefix_caching:
                 logger.info(
                     "Warning: Prefix caching is currently enabled. "
                     "Its support for Mamba2 layers is experimental. "

View File

@@ -37,7 +37,13 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.sequence import IntermediateTensors

-from .interfaces import HasInnerState, IsHybrid, SupportsLoRA, SupportsPP
+from .interfaces import (
+    HasInnerState,
+    IsHybrid,
+    SupportsLoRA,
+    SupportsMambaPrefixCaching,
+    SupportsPP,
+)
 from .utils import (
     PPMissingLayer,
     is_pp_missing_parameter,

@@ -495,7 +501,14 @@ class FalconH1Model(nn.Module):
         return hidden_states


-class FalconH1ForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP, IsHybrid):
+class FalconH1ForCausalLM(
+    nn.Module,
+    HasInnerState,
+    SupportsLoRA,
+    SupportsPP,
+    IsHybrid,
+    SupportsMambaPrefixCaching,
+):
     packed_modules_mapping = {
         "qkv_proj": ["q_proj", "k_proj", "v_proj"],
         "gate_up_proj": ["gate_proj", "up_proj"],

View File

@@ -34,7 +34,14 @@ from vllm.sequence import IntermediateTensors

 from .granitemoe import GraniteMoeMoE
 from .granitemoeshared import GraniteMoeSharedMLP
-from .interfaces import HasInnerState, IsHybrid, SupportsLoRA, SupportsPP, SupportsQuant
+from .interfaces import (
+    HasInnerState,
+    IsHybrid,
+    SupportsLoRA,
+    SupportsMambaPrefixCaching,
+    SupportsPP,
+    SupportsQuant,
+)
 from .utils import (
     AutoWeightsLoader,
     is_pp_missing_parameter,

@@ -584,7 +591,13 @@ class GraniteMoeHybridModel(nn.Module):


 class GraniteMoeHybridForCausalLM(
-    nn.Module, HasInnerState, SupportsLoRA, SupportsPP, IsHybrid, SupportsQuant
+    nn.Module,
+    HasInnerState,
+    SupportsLoRA,
+    SupportsPP,
+    IsHybrid,
+    SupportsQuant,
+    SupportsMambaPrefixCaching,
 ):
     packed_modules_mapping = {
         "qkv_proj": [

View File

@@ -697,6 +697,34 @@ def has_noops(
     return getattr(model, "has_noops", False)


+@runtime_checkable
+class SupportsMambaPrefixCaching(Protocol):
+    """The interface for models whose mamba layers support prefix caching.
+
+    This is currently experimental.
+    """
+
+    supports_mamba_prefix_caching: ClassVar[Literal[True]] = True
+
+
+@overload
+def supports_mamba_prefix_caching(
+    model: object,
+) -> TypeIs[SupportsMambaPrefixCaching]: ...
+
+
+@overload
+def supports_mamba_prefix_caching(
+    model: type[object],
+) -> TypeIs[type[SupportsMambaPrefixCaching]]: ...
+
+
+def supports_mamba_prefix_caching(
+    model: type[object] | object,
+) -> TypeIs[type[SupportsMambaPrefixCaching]] | TypeIs[SupportsMambaPrefixCaching]:
+    return getattr(model, "supports_mamba_prefix_caching", False)
+
+
 @runtime_checkable
 class SupportsCrossEncoding(Protocol):
     """The interface required for all models that support cross encoding."""

View File

@@ -25,7 +25,11 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
     VocabParallelEmbedding,
 )
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
-from vllm.model_executor.models.interfaces import HasInnerState, IsAttentionFree
+from vllm.model_executor.models.interfaces import (
+    HasInnerState,
+    IsAttentionFree,
+    SupportsMambaPrefixCaching,
+)
 from vllm.sequence import IntermediateTensors

 from .utils import (

@@ -189,7 +193,9 @@ class Mamba2Model(nn.Module):
         return loaded_params


-class Mamba2ForCausalLM(nn.Module, HasInnerState, IsAttentionFree):
+class Mamba2ForCausalLM(
+    nn.Module, HasInnerState, IsAttentionFree, SupportsMambaPrefixCaching
+):
     @classmethod
     def get_mamba_state_dtype_from_config(
         cls,

View File

@@ -62,6 +62,7 @@ from vllm.model_executor.models.interfaces import (
     IsHybrid,
     MixtureOfExperts,
     SupportsLoRA,
+    SupportsMambaPrefixCaching,
     SupportsPP,
     SupportsQuant,
 )

@@ -695,6 +696,7 @@ class NemotronHForCausalLM(
     IsHybrid,
     SupportsQuant,
     MixtureOfExperts,
+    SupportsMambaPrefixCaching,
 ):
     hf_to_vllm_mapper = WeightsMapper(
         orig_to_new_prefix={"backbone": "model"},

View File

@@ -39,6 +39,7 @@ from .interfaces import (
     is_attention_free,
     is_hybrid,
     supports_cross_encoding,
+    supports_mamba_prefix_caching,
     supports_multimodal,
     supports_multimodal_encoder_tp_data,
     supports_multimodal_raw_input_only,

@@ -496,6 +497,7 @@ class _ModelInfo:
     is_attention_free: bool
     is_hybrid: bool
     has_noops: bool
+    supports_mamba_prefix_caching: bool
     supports_transcription: bool
     supports_transcription_only: bool

@@ -518,6 +520,7 @@ class _ModelInfo:
             has_inner_state=has_inner_state(model),
             is_attention_free=is_attention_free(model),
             is_hybrid=is_hybrid(model),
+            supports_mamba_prefix_caching=supports_mamba_prefix_caching(model),
             supports_transcription=supports_transcription(model),
             supports_transcription_only=(
                 supports_transcription(model) and model.supports_transcription_only
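The registry side follows the existing capability-snapshot pattern: each helper is evaluated once per model class and frozen into a _ModelInfo field, which ModelConfig then exposes. A trimmed-down sketch of that pattern; the real _ModelInfo has many more fields, and ModelInfoSketch / build_info / PlainTransformer are made-up names:

from dataclasses import dataclass


@dataclass(frozen=True)
class ModelInfoSketch:
    # Stand-in for vLLM's _ModelInfo, reduced to the new field.
    supports_mamba_prefix_caching: bool


def build_info(model: type) -> ModelInfoSketch:
    # Evaluated once at registration time, mirroring _ModelInfo.from_model.
    return ModelInfoSketch(
        supports_mamba_prefix_caching=getattr(
            model, "supports_mamba_prefix_caching", False
        )
    )


class PlainTransformer:  # hypothetical model without the mixin
    pass


assert build_info(PlainTransformer).supports_mamba_prefix_caching is False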

View File

@@ -45,7 +45,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.sequence import IntermediateTensors

-from .interfaces import HasInnerState, IsHybrid
+from .interfaces import HasInnerState, IsHybrid, SupportsMambaPrefixCaching
 from .utils import AutoWeightsLoader, WeightsMapper, maybe_prefix

@@ -824,7 +824,7 @@ class Zamba2Model(nn.Module):
         return loaded_params


-class Zamba2ForCausalLM(nn.Module, HasInnerState, IsHybrid):
+class Zamba2ForCausalLM(nn.Module, HasInnerState, IsHybrid, SupportsMambaPrefixCaching):
     """Zamba2 model with causal language modeling head.

     This class wraps the core Zamba2 model and adds: