[Hybrid] Added supports_mamba_prefix_caching Protocol (#27339)

Signed-off-by: asafg <39553475+Josephasafg@users.noreply.github.com>
Asaf Joseph Gardin 2025-10-27 15:05:20 +02:00 committed by GitHub
parent f4e8154076
commit 9273754222
10 changed files with 93 additions and 20 deletions


@@ -1656,6 +1656,10 @@ class ModelConfig:
     def has_inner_state(self):
         return self._model_info.has_inner_state
 
+    @property
+    def supports_mamba_prefix_caching(self) -> bool:
+        return self._model_info.supports_mamba_prefix_caching
+
     @property
     def use_mla(self) -> bool:
         return self.is_deepseek_mla and not envs.VLLM_MLA_DISABLE
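Note: a minimal, self-contained sketch of the delegation pattern this hunk adds (the Fake* classes below are stand-ins, not the real vLLM types): ModelConfig exposes the flag as a read-only property that simply forwards to the resolved model-info record.

from dataclasses import dataclass


@dataclass(frozen=True)
class FakeModelInfo:
    # Stand-in for the registry's _ModelInfo record.
    has_inner_state: bool = False
    supports_mamba_prefix_caching: bool = False


class FakeModelConfig:
    # Stand-in for ModelConfig; only the forwarding property is shown.
    def __init__(self, model_info: FakeModelInfo) -> None:
        self._model_info = model_info

    @property
    def supports_mamba_prefix_caching(self) -> bool:
        return self._model_info.supports_mamba_prefix_caching


cfg = FakeModelConfig(FakeModelInfo(supports_mamba_prefix_caching=True))
assert cfg.supports_mamba_prefix_caching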


@@ -37,7 +37,14 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.sequence import IntermediateTensors
 
-from .interfaces import HasInnerState, IsHybrid, SupportsLoRA, SupportsPP, SupportsQuant
+from .interfaces import (
+    HasInnerState,
+    IsHybrid,
+    SupportsLoRA,
+    SupportsMambaPrefixCaching,
+    SupportsPP,
+    SupportsQuant,
+)
 from .utils import (
     AutoWeightsLoader,
     is_pp_missing_parameter,
@@ -394,7 +401,13 @@ class BambaModel(nn.Module):
 
 
 class BambaForCausalLM(
-    nn.Module, HasInnerState, SupportsLoRA, SupportsPP, IsHybrid, SupportsQuant
+    nn.Module,
+    HasInnerState,
+    SupportsLoRA,
+    SupportsPP,
+    IsHybrid,
+    SupportsQuant,
+    SupportsMambaPrefixCaching,
 ):
     packed_modules_mapping = {
         "qkv_proj": [


@@ -295,17 +295,8 @@ class MambaModelConfig(VerifyAndUpdateConfig):
         # override by prefix caching logic later)
         cache_config.mamba_block_size = model_config.max_model_len
 
-        # TODO(@tdoublep) find a better way to do this than whitelist
-        MAMBA2_MODELS = [
-            "BambaForCausalLM",
-            "FalconH1ForCausalLM",
-            "GraniteMoeHybridForCausalLM",
-            "Mamba2ForCausalLM",
-            "NemotronHForCausalLM",
-            "Zamba2ForCausalLM",
-        ]
         if cache_config.enable_prefix_caching:
-            if model_config.architecture in MAMBA2_MODELS:
+            if model_config.supports_mamba_prefix_caching:
                 logger.info(
                     "Warning: Prefix caching is currently enabled. "
                     "Its support for Mamba2 layers is experimental. "


@@ -37,7 +37,13 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.sequence import IntermediateTensors
 
-from .interfaces import HasInnerState, IsHybrid, SupportsLoRA, SupportsPP
+from .interfaces import (
+    HasInnerState,
+    IsHybrid,
+    SupportsLoRA,
+    SupportsMambaPrefixCaching,
+    SupportsPP,
+)
 from .utils import (
     PPMissingLayer,
     is_pp_missing_parameter,
@@ -495,7 +501,14 @@ class FalconH1Model(nn.Module):
         return hidden_states
 
 
-class FalconH1ForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP, IsHybrid):
+class FalconH1ForCausalLM(
+    nn.Module,
+    HasInnerState,
+    SupportsLoRA,
+    SupportsPP,
+    IsHybrid,
+    SupportsMambaPrefixCaching,
+):
     packed_modules_mapping = {
         "qkv_proj": ["q_proj", "k_proj", "v_proj"],
         "gate_up_proj": ["gate_proj", "up_proj"],


@@ -34,7 +34,14 @@ from vllm.sequence import IntermediateTensors
 
 from .granitemoe import GraniteMoeMoE
 from .granitemoeshared import GraniteMoeSharedMLP
-from .interfaces import HasInnerState, IsHybrid, SupportsLoRA, SupportsPP, SupportsQuant
+from .interfaces import (
+    HasInnerState,
+    IsHybrid,
+    SupportsLoRA,
+    SupportsMambaPrefixCaching,
+    SupportsPP,
+    SupportsQuant,
+)
 from .utils import (
     AutoWeightsLoader,
     is_pp_missing_parameter,
@@ -584,7 +591,13 @@ class GraniteMoeHybridModel(nn.Module):
 
 
 class GraniteMoeHybridForCausalLM(
-    nn.Module, HasInnerState, SupportsLoRA, SupportsPP, IsHybrid, SupportsQuant
+    nn.Module,
+    HasInnerState,
+    SupportsLoRA,
+    SupportsPP,
+    IsHybrid,
+    SupportsQuant,
+    SupportsMambaPrefixCaching,
 ):
     packed_modules_mapping = {
         "qkv_proj": [


@@ -697,6 +697,34 @@ def has_noops(
     return getattr(model, "has_noops", False)
 
 
+@runtime_checkable
+class SupportsMambaPrefixCaching(Protocol):
+    """The interface for models whose mamba layers support prefix caching.
+
+    This is currently experimental.
+    """
+
+    supports_mamba_prefix_caching: ClassVar[Literal[True]] = True
+
+
+@overload
+def supports_mamba_prefix_caching(
+    model: object,
+) -> TypeIs[SupportsMambaPrefixCaching]: ...
+
+
+@overload
+def supports_mamba_prefix_caching(
+    model: type[object],
+) -> TypeIs[type[SupportsMambaPrefixCaching]]: ...
+
+
+def supports_mamba_prefix_caching(
+    model: type[object] | object,
+) -> TypeIs[type[SupportsMambaPrefixCaching]] | TypeIs[SupportsMambaPrefixCaching]:
+    return getattr(model, "supports_mamba_prefix_caching", False)
+
+
 @runtime_checkable
 class SupportsCrossEncoding(Protocol):
     """The interface required for all models that support cross encoding."""


@@ -25,7 +25,11 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
     VocabParallelEmbedding,
 )
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
-from vllm.model_executor.models.interfaces import HasInnerState, IsAttentionFree
+from vllm.model_executor.models.interfaces import (
+    HasInnerState,
+    IsAttentionFree,
+    SupportsMambaPrefixCaching,
+)
 from vllm.sequence import IntermediateTensors
 
 from .utils import (
@@ -189,7 +193,9 @@ class Mamba2Model(nn.Module):
         return loaded_params
 
 
-class Mamba2ForCausalLM(nn.Module, HasInnerState, IsAttentionFree):
+class Mamba2ForCausalLM(
+    nn.Module, HasInnerState, IsAttentionFree, SupportsMambaPrefixCaching
+):
     @classmethod
     def get_mamba_state_dtype_from_config(
         cls,


@@ -62,6 +62,7 @@ from vllm.model_executor.models.interfaces import (
     IsHybrid,
     MixtureOfExperts,
     SupportsLoRA,
+    SupportsMambaPrefixCaching,
     SupportsPP,
     SupportsQuant,
 )
@@ -695,6 +696,7 @@ class NemotronHForCausalLM(
     IsHybrid,
     SupportsQuant,
     MixtureOfExperts,
+    SupportsMambaPrefixCaching,
 ):
     hf_to_vllm_mapper = WeightsMapper(
         orig_to_new_prefix={"backbone": "model"},


@@ -39,6 +39,7 @@ from .interfaces import (
     is_attention_free,
     is_hybrid,
     supports_cross_encoding,
+    supports_mamba_prefix_caching,
     supports_multimodal,
     supports_multimodal_encoder_tp_data,
     supports_multimodal_raw_input_only,
@@ -496,6 +497,7 @@ class _ModelInfo:
     is_attention_free: bool
     is_hybrid: bool
     has_noops: bool
+    supports_mamba_prefix_caching: bool
     supports_transcription: bool
     supports_transcription_only: bool
@@ -518,6 +520,7 @@ class _ModelInfo:
             has_inner_state=has_inner_state(model),
             is_attention_free=is_attention_free(model),
             is_hybrid=is_hybrid(model),
+            supports_mamba_prefix_caching=supports_mamba_prefix_caching(model),
             supports_transcription=supports_transcription(model),
             supports_transcription_only=(
                 supports_transcription(model) and model.supports_transcription_only
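Note: a rough sketch of how the registry wires this up (trimmed stand-ins; the real _ModelInfo carries many more fields): at inspection time the interface helper is evaluated once against the model class and stored as a plain boolean, which ModelConfig.supports_mamba_prefix_caching later reads.

from dataclasses import dataclass


def supports_mamba_prefix_caching(model: object) -> bool:
    # Same attribute lookup the interface helper performs.
    return getattr(model, "supports_mamba_prefix_caching", False)


@dataclass(frozen=True)
class FakeModelInfo:
    # Trimmed stand-in for the registry's _ModelInfo.
    supports_mamba_prefix_caching: bool

    @classmethod
    def from_model(cls, model: type) -> "FakeModelInfo":
        return cls(
            supports_mamba_prefix_caching=supports_mamba_prefix_caching(model),
        )


class ToyHybridModel:
    supports_mamba_prefix_caching = True


assert FakeModelInfo.from_model(ToyHybridModel).supports_mamba_prefix_caching
assert not FakeModelInfo.from_model(object).supports_mamba_prefix_caching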


@@ -45,7 +45,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.sequence import IntermediateTensors
 
-from .interfaces import HasInnerState, IsHybrid
+from .interfaces import HasInnerState, IsHybrid, SupportsMambaPrefixCaching
 from .utils import AutoWeightsLoader, WeightsMapper, maybe_prefix
@@ -824,7 +824,7 @@ class Zamba2Model(nn.Module):
         return loaded_params
 
 
-class Zamba2ForCausalLM(nn.Module, HasInnerState, IsHybrid):
+class Zamba2ForCausalLM(nn.Module, HasInnerState, IsHybrid, SupportsMambaPrefixCaching):
     """Zamba2 model with causal language modeling head.
 
     This class wraps the core Zamba2 model and adds: