diff --git a/vllm/model_executor/models/phi4flash.py b/vllm/model_executor/models/phi4flash.py
index 493a4192d35a..fcdfcb7bc160 100644
--- a/vllm/model_executor/models/phi4flash.py
+++ b/vllm/model_executor/models/phi4flash.py
@@ -650,8 +650,12 @@ class Phi4FlashForCausalLM(nn.Module, HasInnerState, IsHybrid, SupportsV0Only):
             num_mamba_layers = self.config.num_hidden_layers \
                 // 2 // self.config.mb_per_layer + 1
             self.mamba_cache = MambaCacheManager(
-                self.vllm_config, self.lm_head.weight.dtype, num_mamba_layers,
-                *self._get_mamba_cache_shape())
+                self.vllm_config,
+                num_mamba_layers,
+                *self._get_mamba_cache_shape(),
+                self.lm_head.weight.dtype,
+                self.lm_head.weight.dtype,
+            )
         mamba_cache_params = self.mamba_cache.current_run_tensors(**kwargs)
 
         attn_metadata = get_forward_context().attn_metadata
diff --git a/vllm/model_executor/models/plamo2.py b/vllm/model_executor/models/plamo2.py
index 8b1df66f0280..e5034b536266 100644
--- a/vllm/model_executor/models/plamo2.py
+++ b/vllm/model_executor/models/plamo2.py
@@ -767,8 +767,12 @@ class Plamo2ForCausalLM(Plamo2PreTrainedModel, HasInnerState, SupportsPP,
                 self.vllm_config.parallel_config, LayerBlockType.mamba)
             self.mamba_cache = MambaCacheManager(
-                self.vllm_config, self.lm_head.weight.dtype, num_mamba_layers,
-                *self._get_mamba_cache_shape())
+                self.vllm_config,
+                num_mamba_layers,
+                *self._get_mamba_cache_shape(),
+                self.lm_head.weight.dtype,
+                self.lm_head.weight.dtype,
+            )
         mamba_cache_params = self.mamba_cache.current_run_tensors(**kwargs)
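
Note (not part of the patch): both hunks make the same mechanical change to the MambaCacheManager call site, sketched below for review. The diff passes everything positionally, so the trailing pair of self.lm_head.weight.dtype arguments presumably maps to separate dtypes for the two cache tensors, the conv state and the SSM state; those parameter names are assumptions, not taken from the vLLM source.

    # Before: a single dtype, passed as the second argument.
    self.mamba_cache = MambaCacheManager(
        self.vllm_config, self.lm_head.weight.dtype, num_mamba_layers,
        *self._get_mamba_cache_shape())

    # After: layer count and unpacked cache shapes first, then two dtypes.
    # Both models currently pass the lm_head dtype for each.
    self.mamba_cache = MambaCacheManager(
        self.vllm_config,
        num_mamba_layers,
        *self._get_mamba_cache_shape(),
        self.lm_head.weight.dtype,  # assumed: conv state dtype
        self.lm_head.weight.dtype,  # assumed: SSM state dtype
    )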