From f5d412bafbd9d4700ff57cb6a2d5220cf2b7637e Mon Sep 17 00:00:00 2001
From: Thomas Parnell
Date: Sat, 16 Aug 2025 00:55:26 +0200
Subject: [PATCH] [BugFix] Fix regression caused by mamba state dtype PR
 (#22998)

Signed-off-by: Thomas Parnell
---
 vllm/model_executor/models/phi4flash.py | 8 ++++++--
 vllm/model_executor/models/plamo2.py    | 8 ++++++--
 2 files changed, 12 insertions(+), 4 deletions(-)

diff --git a/vllm/model_executor/models/phi4flash.py b/vllm/model_executor/models/phi4flash.py
index 493a4192d35a..fcdfcb7bc160 100644
--- a/vllm/model_executor/models/phi4flash.py
+++ b/vllm/model_executor/models/phi4flash.py
@@ -650,8 +650,12 @@ class Phi4FlashForCausalLM(nn.Module, HasInnerState, IsHybrid, SupportsV0Only):
             num_mamba_layers = self.config.num_hidden_layers \
                 // 2 // self.config.mb_per_layer + 1
             self.mamba_cache = MambaCacheManager(
-                self.vllm_config, self.lm_head.weight.dtype, num_mamba_layers,
-                *self._get_mamba_cache_shape())
+                self.vllm_config,
+                num_mamba_layers,
+                *self._get_mamba_cache_shape(),
+                self.lm_head.weight.dtype,
+                self.lm_head.weight.dtype,
+            )

         mamba_cache_params = self.mamba_cache.current_run_tensors(**kwargs)
         attn_metadata = get_forward_context().attn_metadata
diff --git a/vllm/model_executor/models/plamo2.py b/vllm/model_executor/models/plamo2.py
index 8b1df66f0280..e5034b536266 100644
--- a/vllm/model_executor/models/plamo2.py
+++ b/vllm/model_executor/models/plamo2.py
@@ -767,8 +767,12 @@ class Plamo2ForCausalLM(Plamo2PreTrainedModel, HasInnerState, SupportsPP,
                 self.vllm_config.parallel_config, LayerBlockType.mamba)
             self.mamba_cache = MambaCacheManager(
-                self.vllm_config, self.lm_head.weight.dtype, num_mamba_layers,
-                *self._get_mamba_cache_shape())
+                self.vllm_config,
+                num_mamba_layers,
+                *self._get_mamba_cache_shape(),
+                self.lm_head.weight.dtype,
+                self.lm_head.weight.dtype,
+            )

         mamba_cache_params = self.mamba_cache.current_run_tensors(**kwargs)
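
For context: the regression came from the mamba state dtype PR changing the
MambaCacheManager constructor so that the state dtypes follow the cache shapes
instead of a single dtype preceding the layer count, while the phi4flash and
plamo2 call sites still used the old argument order. Below is a minimal sketch
of the constructor ordering the updated call sites assume; the parameter names
are assumptions for illustration (the real definition lives in
vllm/model_executor/models/mamba_cache.py), only the positional order is taken
from this diff.

    # Sketch of the MambaCacheManager constructor ordering implied by the new
    # call sites in this patch. Parameter names are illustrative assumptions;
    # only the order (config, layer count, shapes, then two dtypes) comes from
    # the diff, assuming _get_mamba_cache_shape() returns the conv and
    # temporal (SSM) state shapes.
    import torch

    from vllm.config import VllmConfig


    class MambaCacheManager:

        def __init__(
            self,
            vllm_config: VllmConfig,
            num_mamba_layers: int,
            conv_state_shape: tuple[int, ...],
            temporal_state_shape: tuple[int, ...],
            conv_state_dtype: torch.dtype,
            temporal_state_dtype: torch.dtype,
        ) -> None:
            ...

Under this ordering, the old call sites were passing the lm_head weight dtype
where the layer count is now expected, which is the mismatch this patch
corrects by moving the dtype arguments after the unpacked cache shapes.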