mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-05-06 20:31:21 +08:00
[Hybrid] Added supports_mamba_prefix_caching Protocol (#27339)
Signed-off-by: asafg <39553475+Josephasafg@users.noreply.github.com>
This commit is contained in:
parent
f4e8154076
commit
9273754222
@ -1656,6 +1656,10 @@ class ModelConfig:
|
|||||||
def has_inner_state(self):
|
def has_inner_state(self):
|
||||||
return self._model_info.has_inner_state
|
return self._model_info.has_inner_state
|
||||||
|
|
||||||
|
@property
|
||||||
|
def supports_mamba_prefix_caching(self) -> bool:
|
||||||
|
return self._model_info.supports_mamba_prefix_caching
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def use_mla(self) -> bool:
|
def use_mla(self) -> bool:
|
||||||
return self.is_deepseek_mla and not envs.VLLM_MLA_DISABLE
|
return self.is_deepseek_mla and not envs.VLLM_MLA_DISABLE
|
||||||
|
|||||||
@ -37,7 +37,14 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
|
|||||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||||
from vllm.sequence import IntermediateTensors
|
from vllm.sequence import IntermediateTensors
|
||||||
|
|
||||||
from .interfaces import HasInnerState, IsHybrid, SupportsLoRA, SupportsPP, SupportsQuant
|
from .interfaces import (
|
||||||
|
HasInnerState,
|
||||||
|
IsHybrid,
|
||||||
|
SupportsLoRA,
|
||||||
|
SupportsMambaPrefixCaching,
|
||||||
|
SupportsPP,
|
||||||
|
SupportsQuant,
|
||||||
|
)
|
||||||
from .utils import (
|
from .utils import (
|
||||||
AutoWeightsLoader,
|
AutoWeightsLoader,
|
||||||
is_pp_missing_parameter,
|
is_pp_missing_parameter,
|
||||||
@ -394,7 +401,13 @@ class BambaModel(nn.Module):
|
|||||||
|
|
||||||
|
|
||||||
class BambaForCausalLM(
|
class BambaForCausalLM(
|
||||||
nn.Module, HasInnerState, SupportsLoRA, SupportsPP, IsHybrid, SupportsQuant
|
nn.Module,
|
||||||
|
HasInnerState,
|
||||||
|
SupportsLoRA,
|
||||||
|
SupportsPP,
|
||||||
|
IsHybrid,
|
||||||
|
SupportsQuant,
|
||||||
|
SupportsMambaPrefixCaching,
|
||||||
):
|
):
|
||||||
packed_modules_mapping = {
|
packed_modules_mapping = {
|
||||||
"qkv_proj": [
|
"qkv_proj": [
|
||||||
|
|||||||
@ -295,17 +295,8 @@ class MambaModelConfig(VerifyAndUpdateConfig):
|
|||||||
# override by prefix caching logic later)
|
# override by prefix caching logic later)
|
||||||
cache_config.mamba_block_size = model_config.max_model_len
|
cache_config.mamba_block_size = model_config.max_model_len
|
||||||
|
|
||||||
# TODO(@tdoublep) find a better way to do this than whitelist
|
|
||||||
MAMBA2_MODELS = [
|
|
||||||
"BambaForCausalLM",
|
|
||||||
"FalconH1ForCausalLM",
|
|
||||||
"GraniteMoeHybridForCausalLM",
|
|
||||||
"Mamba2ForCausalLM",
|
|
||||||
"NemotronHForCausalLM",
|
|
||||||
"Zamba2ForCausalLM",
|
|
||||||
]
|
|
||||||
if cache_config.enable_prefix_caching:
|
if cache_config.enable_prefix_caching:
|
||||||
if model_config.architecture in MAMBA2_MODELS:
|
if model_config.supports_mamba_prefix_caching:
|
||||||
logger.info(
|
logger.info(
|
||||||
"Warning: Prefix caching is currently enabled. "
|
"Warning: Prefix caching is currently enabled. "
|
||||||
"Its support for Mamba2 layers is experimental. "
|
"Its support for Mamba2 layers is experimental. "
|
||||||
|
|||||||
@ -37,7 +37,13 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
|
|||||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||||
from vllm.sequence import IntermediateTensors
|
from vllm.sequence import IntermediateTensors
|
||||||
|
|
||||||
from .interfaces import HasInnerState, IsHybrid, SupportsLoRA, SupportsPP
|
from .interfaces import (
|
||||||
|
HasInnerState,
|
||||||
|
IsHybrid,
|
||||||
|
SupportsLoRA,
|
||||||
|
SupportsMambaPrefixCaching,
|
||||||
|
SupportsPP,
|
||||||
|
)
|
||||||
from .utils import (
|
from .utils import (
|
||||||
PPMissingLayer,
|
PPMissingLayer,
|
||||||
is_pp_missing_parameter,
|
is_pp_missing_parameter,
|
||||||
@ -495,7 +501,14 @@ class FalconH1Model(nn.Module):
|
|||||||
return hidden_states
|
return hidden_states
|
||||||
|
|
||||||
|
|
||||||
class FalconH1ForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP, IsHybrid):
|
class FalconH1ForCausalLM(
|
||||||
|
nn.Module,
|
||||||
|
HasInnerState,
|
||||||
|
SupportsLoRA,
|
||||||
|
SupportsPP,
|
||||||
|
IsHybrid,
|
||||||
|
SupportsMambaPrefixCaching,
|
||||||
|
):
|
||||||
packed_modules_mapping = {
|
packed_modules_mapping = {
|
||||||
"qkv_proj": ["q_proj", "k_proj", "v_proj"],
|
"qkv_proj": ["q_proj", "k_proj", "v_proj"],
|
||||||
"gate_up_proj": ["gate_proj", "up_proj"],
|
"gate_up_proj": ["gate_proj", "up_proj"],
|
||||||
|
|||||||
@ -34,7 +34,14 @@ from vllm.sequence import IntermediateTensors
|
|||||||
|
|
||||||
from .granitemoe import GraniteMoeMoE
|
from .granitemoe import GraniteMoeMoE
|
||||||
from .granitemoeshared import GraniteMoeSharedMLP
|
from .granitemoeshared import GraniteMoeSharedMLP
|
||||||
from .interfaces import HasInnerState, IsHybrid, SupportsLoRA, SupportsPP, SupportsQuant
|
from .interfaces import (
|
||||||
|
HasInnerState,
|
||||||
|
IsHybrid,
|
||||||
|
SupportsLoRA,
|
||||||
|
SupportsMambaPrefixCaching,
|
||||||
|
SupportsPP,
|
||||||
|
SupportsQuant,
|
||||||
|
)
|
||||||
from .utils import (
|
from .utils import (
|
||||||
AutoWeightsLoader,
|
AutoWeightsLoader,
|
||||||
is_pp_missing_parameter,
|
is_pp_missing_parameter,
|
||||||
@ -584,7 +591,13 @@ class GraniteMoeHybridModel(nn.Module):
|
|||||||
|
|
||||||
|
|
||||||
class GraniteMoeHybridForCausalLM(
|
class GraniteMoeHybridForCausalLM(
|
||||||
nn.Module, HasInnerState, SupportsLoRA, SupportsPP, IsHybrid, SupportsQuant
|
nn.Module,
|
||||||
|
HasInnerState,
|
||||||
|
SupportsLoRA,
|
||||||
|
SupportsPP,
|
||||||
|
IsHybrid,
|
||||||
|
SupportsQuant,
|
||||||
|
SupportsMambaPrefixCaching,
|
||||||
):
|
):
|
||||||
packed_modules_mapping = {
|
packed_modules_mapping = {
|
||||||
"qkv_proj": [
|
"qkv_proj": [
|
||||||
|
|||||||
@ -697,6 +697,34 @@ def has_noops(
|
|||||||
return getattr(model, "has_noops", False)
|
return getattr(model, "has_noops", False)
|
||||||
|
|
||||||
|
|
||||||
|
@runtime_checkable
|
||||||
|
class SupportsMambaPrefixCaching(Protocol):
|
||||||
|
"""The interface for models whose mamba layers support prefix caching.
|
||||||
|
|
||||||
|
This is currently experimental.
|
||||||
|
"""
|
||||||
|
|
||||||
|
supports_mamba_prefix_caching: ClassVar[Literal[True]] = True
|
||||||
|
|
||||||
|
|
||||||
|
@overload
|
||||||
|
def supports_mamba_prefix_caching(
|
||||||
|
model: object,
|
||||||
|
) -> TypeIs[SupportsMambaPrefixCaching]: ...
|
||||||
|
|
||||||
|
|
||||||
|
@overload
|
||||||
|
def supports_mamba_prefix_caching(
|
||||||
|
model: type[object],
|
||||||
|
) -> TypeIs[type[SupportsMambaPrefixCaching]]: ...
|
||||||
|
|
||||||
|
|
||||||
|
def supports_mamba_prefix_caching(
|
||||||
|
model: type[object] | object,
|
||||||
|
) -> TypeIs[type[SupportsMambaPrefixCaching]] | TypeIs[SupportsMambaPrefixCaching]:
|
||||||
|
return getattr(model, "supports_mamba_prefix_caching", False)
|
||||||
|
|
||||||
|
|
||||||
@runtime_checkable
|
@runtime_checkable
|
||||||
class SupportsCrossEncoding(Protocol):
|
class SupportsCrossEncoding(Protocol):
|
||||||
"""The interface required for all models that support cross encoding."""
|
"""The interface required for all models that support cross encoding."""
|
||||||
|
|||||||
@ -25,7 +25,11 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
|
|||||||
VocabParallelEmbedding,
|
VocabParallelEmbedding,
|
||||||
)
|
)
|
||||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||||
from vllm.model_executor.models.interfaces import HasInnerState, IsAttentionFree
|
from vllm.model_executor.models.interfaces import (
|
||||||
|
HasInnerState,
|
||||||
|
IsAttentionFree,
|
||||||
|
SupportsMambaPrefixCaching,
|
||||||
|
)
|
||||||
from vllm.sequence import IntermediateTensors
|
from vllm.sequence import IntermediateTensors
|
||||||
|
|
||||||
from .utils import (
|
from .utils import (
|
||||||
@ -189,7 +193,9 @@ class Mamba2Model(nn.Module):
|
|||||||
return loaded_params
|
return loaded_params
|
||||||
|
|
||||||
|
|
||||||
class Mamba2ForCausalLM(nn.Module, HasInnerState, IsAttentionFree):
|
class Mamba2ForCausalLM(
|
||||||
|
nn.Module, HasInnerState, IsAttentionFree, SupportsMambaPrefixCaching
|
||||||
|
):
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_mamba_state_dtype_from_config(
|
def get_mamba_state_dtype_from_config(
|
||||||
cls,
|
cls,
|
||||||
|
|||||||
@ -62,6 +62,7 @@ from vllm.model_executor.models.interfaces import (
|
|||||||
IsHybrid,
|
IsHybrid,
|
||||||
MixtureOfExperts,
|
MixtureOfExperts,
|
||||||
SupportsLoRA,
|
SupportsLoRA,
|
||||||
|
SupportsMambaPrefixCaching,
|
||||||
SupportsPP,
|
SupportsPP,
|
||||||
SupportsQuant,
|
SupportsQuant,
|
||||||
)
|
)
|
||||||
@ -695,6 +696,7 @@ class NemotronHForCausalLM(
|
|||||||
IsHybrid,
|
IsHybrid,
|
||||||
SupportsQuant,
|
SupportsQuant,
|
||||||
MixtureOfExperts,
|
MixtureOfExperts,
|
||||||
|
SupportsMambaPrefixCaching,
|
||||||
):
|
):
|
||||||
hf_to_vllm_mapper = WeightsMapper(
|
hf_to_vllm_mapper = WeightsMapper(
|
||||||
orig_to_new_prefix={"backbone": "model"},
|
orig_to_new_prefix={"backbone": "model"},
|
||||||
|
|||||||
@ -39,6 +39,7 @@ from .interfaces import (
|
|||||||
is_attention_free,
|
is_attention_free,
|
||||||
is_hybrid,
|
is_hybrid,
|
||||||
supports_cross_encoding,
|
supports_cross_encoding,
|
||||||
|
supports_mamba_prefix_caching,
|
||||||
supports_multimodal,
|
supports_multimodal,
|
||||||
supports_multimodal_encoder_tp_data,
|
supports_multimodal_encoder_tp_data,
|
||||||
supports_multimodal_raw_input_only,
|
supports_multimodal_raw_input_only,
|
||||||
@ -496,6 +497,7 @@ class _ModelInfo:
|
|||||||
is_attention_free: bool
|
is_attention_free: bool
|
||||||
is_hybrid: bool
|
is_hybrid: bool
|
||||||
has_noops: bool
|
has_noops: bool
|
||||||
|
supports_mamba_prefix_caching: bool
|
||||||
supports_transcription: bool
|
supports_transcription: bool
|
||||||
supports_transcription_only: bool
|
supports_transcription_only: bool
|
||||||
|
|
||||||
@ -518,6 +520,7 @@ class _ModelInfo:
|
|||||||
has_inner_state=has_inner_state(model),
|
has_inner_state=has_inner_state(model),
|
||||||
is_attention_free=is_attention_free(model),
|
is_attention_free=is_attention_free(model),
|
||||||
is_hybrid=is_hybrid(model),
|
is_hybrid=is_hybrid(model),
|
||||||
|
supports_mamba_prefix_caching=supports_mamba_prefix_caching(model),
|
||||||
supports_transcription=supports_transcription(model),
|
supports_transcription=supports_transcription(model),
|
||||||
supports_transcription_only=(
|
supports_transcription_only=(
|
||||||
supports_transcription(model) and model.supports_transcription_only
|
supports_transcription(model) and model.supports_transcription_only
|
||||||
|
|||||||
@ -45,7 +45,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
|
|||||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||||
from vllm.sequence import IntermediateTensors
|
from vllm.sequence import IntermediateTensors
|
||||||
|
|
||||||
from .interfaces import HasInnerState, IsHybrid
|
from .interfaces import HasInnerState, IsHybrid, SupportsMambaPrefixCaching
|
||||||
from .utils import AutoWeightsLoader, WeightsMapper, maybe_prefix
|
from .utils import AutoWeightsLoader, WeightsMapper, maybe_prefix
|
||||||
|
|
||||||
|
|
||||||
@ -824,7 +824,7 @@ class Zamba2Model(nn.Module):
|
|||||||
return loaded_params
|
return loaded_params
|
||||||
|
|
||||||
|
|
||||||
class Zamba2ForCausalLM(nn.Module, HasInnerState, IsHybrid):
|
class Zamba2ForCausalLM(nn.Module, HasInnerState, IsHybrid, SupportsMambaPrefixCaching):
|
||||||
"""Zamba2 model with causal language modeling head.
|
"""Zamba2 model with causal language modeling head.
|
||||||
|
|
||||||
This class wraps the core Zamba2 model and adds:
|
This class wraps the core Zamba2 model and adds:
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user