[BugFix] Fix regression caused by mamba state dtype PR (#22998)
Signed-off-by: Thomas Parnell <tpa@zurich.ibm.com>
parent 177e55e3bd
commit f5d412bafb
@@ -650,8 +650,12 @@ class Phi4FlashForCausalLM(nn.Module, HasInnerState, IsHybrid, SupportsV0Only):
             num_mamba_layers = self.config.num_hidden_layers \
                 // 2 // self.config.mb_per_layer + 1
             self.mamba_cache = MambaCacheManager(
-                self.vllm_config, self.lm_head.weight.dtype, num_mamba_layers,
-                *self._get_mamba_cache_shape())
+                self.vllm_config,
+                num_mamba_layers,
+                *self._get_mamba_cache_shape(),
+                self.lm_head.weight.dtype,
+                self.lm_head.weight.dtype,
+            )
         mamba_cache_params = self.mamba_cache.current_run_tensors(**kwargs)
 
         attn_metadata = get_forward_context().attn_metadata
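For readers skimming the diff: the mamba state dtype PR changed how the cache manager is constructed, with the layer count and the two state shapes passed before the state dtypes, and a separate dtype for each cached state. The old call sites still passed a single dtype as the second positional argument, which is the regression this commit fixes. Below is a minimal sketch of the corrected positional order; the stand-in class and its field names (conv_state_shape, temporal_state_shape, conv_state_dtype, temporal_state_dtype) are assumptions for illustration, not vLLM's actual MambaCacheManager signature.

```python
from dataclasses import dataclass

import torch


# Hypothetical stand-in: only the positional order (config, layer count,
# the two state shapes, then the two state dtypes) is taken from the diff;
# the field names are assumptions, not vLLM's real MambaCacheManager API.
@dataclass
class MambaCacheManagerSketch:
    vllm_config: object
    num_mamba_layers: int
    conv_state_shape: tuple
    temporal_state_shape: tuple
    conv_state_dtype: torch.dtype
    temporal_state_dtype: torch.dtype


# Stand-in for self._get_mamba_cache_shape(), which returns both state shapes.
shapes = ((4, 128), (16, 64))

# Mirrors the fixed call sites: the shapes are unpacked after the layer count,
# and the lm_head weight dtype is passed once per cached state.
cache = MambaCacheManagerSketch(
    None,            # vllm_config
    16,              # num_mamba_layers
    *shapes,         # conv_state_shape, temporal_state_shape
    torch.float16,   # conv_state_dtype  (self.lm_head.weight.dtype in the diff)
    torch.float16,   # temporal_state_dtype (same dtype in the diff)
)
print(cache.conv_state_dtype, cache.temporal_state_dtype)
```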
@@ -767,8 +767,12 @@ class Plamo2ForCausalLM(Plamo2PreTrainedModel, HasInnerState, SupportsPP,
                 self.vllm_config.parallel_config, LayerBlockType.mamba)
 
             self.mamba_cache = MambaCacheManager(
-                self.vllm_config, self.lm_head.weight.dtype, num_mamba_layers,
-                *self._get_mamba_cache_shape())
+                self.vllm_config,
+                num_mamba_layers,
+                *self._get_mamba_cache_shape(),
+                self.lm_head.weight.dtype,
+                self.lm_head.weight.dtype,
+            )
 
         mamba_cache_params = self.mamba_cache.current_run_tensors(**kwargs)
 
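Note that both new dtype arguments are set to self.lm_head.weight.dtype. Presumably the state dtype PR split the single cache dtype into one dtype per cached state, so passing the lm_head weight dtype twice keeps these models on the same dtype they used before the split. The Plamo2 hunk applies the identical fix as the Phi4Flash one.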