mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-01-24 08:44:28 +08:00
[Hybrid] Added supports_mamba_prefix_caching Protocol (#27339)
Signed-off-by: asafg <39553475+Josephasafg@users.noreply.github.com>
This commit is contained in:
parent
f4e8154076
commit
9273754222
@ -1656,6 +1656,10 @@ class ModelConfig:
|
||||
def has_inner_state(self):
|
||||
return self._model_info.has_inner_state
|
||||
|
||||
@property
|
||||
def supports_mamba_prefix_caching(self) -> bool:
|
||||
return self._model_info.supports_mamba_prefix_caching
|
||||
|
||||
@property
|
||||
def use_mla(self) -> bool:
|
||||
return self.is_deepseek_mla and not envs.VLLM_MLA_DISABLE
|
||||
|
||||
@ -37,7 +37,14 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||
from vllm.sequence import IntermediateTensors
|
||||
|
||||
from .interfaces import HasInnerState, IsHybrid, SupportsLoRA, SupportsPP, SupportsQuant
|
||||
from .interfaces import (
|
||||
HasInnerState,
|
||||
IsHybrid,
|
||||
SupportsLoRA,
|
||||
SupportsMambaPrefixCaching,
|
||||
SupportsPP,
|
||||
SupportsQuant,
|
||||
)
|
||||
from .utils import (
|
||||
AutoWeightsLoader,
|
||||
is_pp_missing_parameter,
|
||||
@ -394,7 +401,13 @@ class BambaModel(nn.Module):
|
||||
|
||||
|
||||
class BambaForCausalLM(
|
||||
nn.Module, HasInnerState, SupportsLoRA, SupportsPP, IsHybrid, SupportsQuant
|
||||
nn.Module,
|
||||
HasInnerState,
|
||||
SupportsLoRA,
|
||||
SupportsPP,
|
||||
IsHybrid,
|
||||
SupportsQuant,
|
||||
SupportsMambaPrefixCaching,
|
||||
):
|
||||
packed_modules_mapping = {
|
||||
"qkv_proj": [
|
||||
|
||||
@ -295,17 +295,8 @@ class MambaModelConfig(VerifyAndUpdateConfig):
|
||||
# override by prefix caching logic later)
|
||||
cache_config.mamba_block_size = model_config.max_model_len
|
||||
|
||||
# TODO(@tdoublep) find a better way to do this than whitelist
|
||||
MAMBA2_MODELS = [
|
||||
"BambaForCausalLM",
|
||||
"FalconH1ForCausalLM",
|
||||
"GraniteMoeHybridForCausalLM",
|
||||
"Mamba2ForCausalLM",
|
||||
"NemotronHForCausalLM",
|
||||
"Zamba2ForCausalLM",
|
||||
]
|
||||
if cache_config.enable_prefix_caching:
|
||||
if model_config.architecture in MAMBA2_MODELS:
|
||||
if model_config.supports_mamba_prefix_caching:
|
||||
logger.info(
|
||||
"Warning: Prefix caching is currently enabled. "
|
||||
"Its support for Mamba2 layers is experimental. "
|
||||
|
||||
@ -37,7 +37,13 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||
from vllm.sequence import IntermediateTensors
|
||||
|
||||
from .interfaces import HasInnerState, IsHybrid, SupportsLoRA, SupportsPP
|
||||
from .interfaces import (
|
||||
HasInnerState,
|
||||
IsHybrid,
|
||||
SupportsLoRA,
|
||||
SupportsMambaPrefixCaching,
|
||||
SupportsPP,
|
||||
)
|
||||
from .utils import (
|
||||
PPMissingLayer,
|
||||
is_pp_missing_parameter,
|
||||
@ -495,7 +501,14 @@ class FalconH1Model(nn.Module):
|
||||
return hidden_states
|
||||
|
||||
|
||||
class FalconH1ForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP, IsHybrid):
|
||||
class FalconH1ForCausalLM(
|
||||
nn.Module,
|
||||
HasInnerState,
|
||||
SupportsLoRA,
|
||||
SupportsPP,
|
||||
IsHybrid,
|
||||
SupportsMambaPrefixCaching,
|
||||
):
|
||||
packed_modules_mapping = {
|
||||
"qkv_proj": ["q_proj", "k_proj", "v_proj"],
|
||||
"gate_up_proj": ["gate_proj", "up_proj"],
|
||||
|
||||
@ -34,7 +34,14 @@ from vllm.sequence import IntermediateTensors
|
||||
|
||||
from .granitemoe import GraniteMoeMoE
|
||||
from .granitemoeshared import GraniteMoeSharedMLP
|
||||
from .interfaces import HasInnerState, IsHybrid, SupportsLoRA, SupportsPP, SupportsQuant
|
||||
from .interfaces import (
|
||||
HasInnerState,
|
||||
IsHybrid,
|
||||
SupportsLoRA,
|
||||
SupportsMambaPrefixCaching,
|
||||
SupportsPP,
|
||||
SupportsQuant,
|
||||
)
|
||||
from .utils import (
|
||||
AutoWeightsLoader,
|
||||
is_pp_missing_parameter,
|
||||
@ -584,7 +591,13 @@ class GraniteMoeHybridModel(nn.Module):
|
||||
|
||||
|
||||
class GraniteMoeHybridForCausalLM(
|
||||
nn.Module, HasInnerState, SupportsLoRA, SupportsPP, IsHybrid, SupportsQuant
|
||||
nn.Module,
|
||||
HasInnerState,
|
||||
SupportsLoRA,
|
||||
SupportsPP,
|
||||
IsHybrid,
|
||||
SupportsQuant,
|
||||
SupportsMambaPrefixCaching,
|
||||
):
|
||||
packed_modules_mapping = {
|
||||
"qkv_proj": [
|
||||
|
||||
@ -697,6 +697,34 @@ def has_noops(
|
||||
return getattr(model, "has_noops", False)
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class SupportsMambaPrefixCaching(Protocol):
|
||||
"""The interface for models whose mamba layers support prefix caching.
|
||||
|
||||
This is currently experimental.
|
||||
"""
|
||||
|
||||
supports_mamba_prefix_caching: ClassVar[Literal[True]] = True
|
||||
|
||||
|
||||
@overload
|
||||
def supports_mamba_prefix_caching(
|
||||
model: object,
|
||||
) -> TypeIs[SupportsMambaPrefixCaching]: ...
|
||||
|
||||
|
||||
@overload
|
||||
def supports_mamba_prefix_caching(
|
||||
model: type[object],
|
||||
) -> TypeIs[type[SupportsMambaPrefixCaching]]: ...
|
||||
|
||||
|
||||
def supports_mamba_prefix_caching(
|
||||
model: type[object] | object,
|
||||
) -> TypeIs[type[SupportsMambaPrefixCaching]] | TypeIs[SupportsMambaPrefixCaching]:
|
||||
return getattr(model, "supports_mamba_prefix_caching", False)
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class SupportsCrossEncoding(Protocol):
|
||||
"""The interface required for all models that support cross encoding."""
|
||||
|
||||
@ -25,7 +25,11 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||
VocabParallelEmbedding,
|
||||
)
|
||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||
from vllm.model_executor.models.interfaces import HasInnerState, IsAttentionFree
|
||||
from vllm.model_executor.models.interfaces import (
|
||||
HasInnerState,
|
||||
IsAttentionFree,
|
||||
SupportsMambaPrefixCaching,
|
||||
)
|
||||
from vllm.sequence import IntermediateTensors
|
||||
|
||||
from .utils import (
|
||||
@ -189,7 +193,9 @@ class Mamba2Model(nn.Module):
|
||||
return loaded_params
|
||||
|
||||
|
||||
class Mamba2ForCausalLM(nn.Module, HasInnerState, IsAttentionFree):
|
||||
class Mamba2ForCausalLM(
|
||||
nn.Module, HasInnerState, IsAttentionFree, SupportsMambaPrefixCaching
|
||||
):
|
||||
@classmethod
|
||||
def get_mamba_state_dtype_from_config(
|
||||
cls,
|
||||
|
||||
@ -62,6 +62,7 @@ from vllm.model_executor.models.interfaces import (
|
||||
IsHybrid,
|
||||
MixtureOfExperts,
|
||||
SupportsLoRA,
|
||||
SupportsMambaPrefixCaching,
|
||||
SupportsPP,
|
||||
SupportsQuant,
|
||||
)
|
||||
@ -695,6 +696,7 @@ class NemotronHForCausalLM(
|
||||
IsHybrid,
|
||||
SupportsQuant,
|
||||
MixtureOfExperts,
|
||||
SupportsMambaPrefixCaching,
|
||||
):
|
||||
hf_to_vllm_mapper = WeightsMapper(
|
||||
orig_to_new_prefix={"backbone": "model"},
|
||||
|
||||
@ -39,6 +39,7 @@ from .interfaces import (
|
||||
is_attention_free,
|
||||
is_hybrid,
|
||||
supports_cross_encoding,
|
||||
supports_mamba_prefix_caching,
|
||||
supports_multimodal,
|
||||
supports_multimodal_encoder_tp_data,
|
||||
supports_multimodal_raw_input_only,
|
||||
@ -496,6 +497,7 @@ class _ModelInfo:
|
||||
is_attention_free: bool
|
||||
is_hybrid: bool
|
||||
has_noops: bool
|
||||
supports_mamba_prefix_caching: bool
|
||||
supports_transcription: bool
|
||||
supports_transcription_only: bool
|
||||
|
||||
@ -518,6 +520,7 @@ class _ModelInfo:
|
||||
has_inner_state=has_inner_state(model),
|
||||
is_attention_free=is_attention_free(model),
|
||||
is_hybrid=is_hybrid(model),
|
||||
supports_mamba_prefix_caching=supports_mamba_prefix_caching(model),
|
||||
supports_transcription=supports_transcription(model),
|
||||
supports_transcription_only=(
|
||||
supports_transcription(model) and model.supports_transcription_only
|
||||
|
||||
@ -45,7 +45,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||
from vllm.sequence import IntermediateTensors
|
||||
|
||||
from .interfaces import HasInnerState, IsHybrid
|
||||
from .interfaces import HasInnerState, IsHybrid, SupportsMambaPrefixCaching
|
||||
from .utils import AutoWeightsLoader, WeightsMapper, maybe_prefix
|
||||
|
||||
|
||||
@ -824,7 +824,7 @@ class Zamba2Model(nn.Module):
|
||||
return loaded_params
|
||||
|
||||
|
||||
class Zamba2ForCausalLM(nn.Module, HasInnerState, IsHybrid):
|
||||
class Zamba2ForCausalLM(nn.Module, HasInnerState, IsHybrid, SupportsMambaPrefixCaching):
|
||||
"""Zamba2 model with causal language modeling head.
|
||||
|
||||
This class wraps the core Zamba2 model and adds:
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user