[BugFix] Fix regression caused by mamba state dtype PR (#22998)

Signed-off-by: Thomas Parnell <tpa@zurich.ibm.com>
Author: Thomas Parnell, 2025-08-16 00:55:26 +02:00 (committed by GitHub)
parent 177e55e3bd
commit f5d412bafb
2 changed files with 12 additions and 4 deletions


@@ -650,8 +650,12 @@ class Phi4FlashForCausalLM(nn.Module, HasInnerState, IsHybrid, SupportsV0Only):
             num_mamba_layers = self.config.num_hidden_layers \
                 // 2 // self.config.mb_per_layer + 1
             self.mamba_cache = MambaCacheManager(
-                self.vllm_config, self.lm_head.weight.dtype, num_mamba_layers,
-                *self._get_mamba_cache_shape())
+                self.vllm_config,
+                num_mamba_layers,
+                *self._get_mamba_cache_shape(),
+                self.lm_head.weight.dtype,
+                self.lm_head.weight.dtype,
+            )
 
         mamba_cache_params = self.mamba_cache.current_run_tensors(**kwargs)
         attn_metadata = get_forward_context().attn_metadata


@@ -767,8 +767,12 @@ class Plamo2ForCausalLM(Plamo2PreTrainedModel, HasInnerState, SupportsPP,
                 self.vllm_config.parallel_config, LayerBlockType.mamba)
             self.mamba_cache = MambaCacheManager(
-                self.vllm_config, self.lm_head.weight.dtype, num_mamba_layers,
-                *self._get_mamba_cache_shape())
+                self.vllm_config,
+                num_mamba_layers,
+                *self._get_mamba_cache_shape(),
+                self.lm_head.weight.dtype,
+                self.lm_head.weight.dtype,
+            )
         mamba_cache_params = self.mamba_cache.current_run_tensors(**kwargs)
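
Note on the fix: the mamba state dtype PR changed MambaCacheManager's positional arguments so that the cache shapes come before the state dtypes, with the conv state dtype and the temporal (SSM) state dtype passed as two separate trailing arguments. The Phi4Flash and Plamo2 call sites above still used the old order (dtype as the second argument), which caused the regression fixed here. Below is a minimal sketch of a constructor signature consistent with the updated call sites; the parameter names and type hints are assumptions inferred from this diff, not vLLM's actual definition.

import torch

# Sketch only: parameter names are inferred from the call sites in this diff
# and are not guaranteed to match vLLM's real MambaCacheManager definition.
class MambaCacheManagerSketch:

    def __init__(
        self,
        vllm_config,                            # engine/model configuration object (assumed)
        num_mamba_layers: int,                  # number of mamba layers that need cache buffers
        conv_state_shape: tuple[int, ...],      # first element of _get_mamba_cache_shape()
        temporal_state_shape: tuple[int, ...],  # second element of _get_mamba_cache_shape()
        conv_state_dtype: torch.dtype,          # dtype of the conv state buffers
        temporal_state_dtype: torch.dtype,      # dtype of the SSM/temporal state buffers
    ) -> None:
        self.conv_state_dtype = conv_state_dtype
        self.temporal_state_dtype = temporal_state_dtype
        # Buffer allocation omitted; this sketch only documents the argument order.

With this ordering, both fixed call sites unpack the two cache shapes via *self._get_mamba_cache_shape() and then pass self.lm_head.weight.dtype twice, once for each state buffer.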