From fe66b34728e5d383e3d19aefc544eeee808c99fb Mon Sep 17 00:00:00 2001
From: Chih-Chieh Yang <7364402+cyang49@users.noreply.github.com>
Date: Fri, 14 Mar 2025 16:36:18 -0400
Subject: [PATCH] [Model] Mamba2 Prefill Performance Tweaks: Fixing Flurry of
 Unnecessary Memory Copies (#14778)

Signed-off-by: Chih-Chieh-Yang <7364402+cyang49@users.noreply.github.com>
---
 .../layers/mamba/mamba_mixer2.py              | 30 ++++++++++++++-----
 1 file changed, 22 insertions(+), 8 deletions(-)

diff --git a/vllm/model_executor/layers/mamba/mamba_mixer2.py b/vllm/model_executor/layers/mamba/mamba_mixer2.py
index b53a540ed662..5b19e3f3554a 100644
--- a/vllm/model_executor/layers/mamba/mamba_mixer2.py
+++ b/vllm/model_executor/layers/mamba/mamba_mixer2.py
@@ -466,10 +466,17 @@ class MambaMixer2(CustomOp):
 
         if has_prefill:
             initial_states = None
-            if has_initial_states is not None and any(has_initial_states):
-                for idx in mamba_cache_params.state_indices_tensor[
-                        ~has_initial_states]:
-                    mamba_cache_params.ssm_state[idx].zero_()
+
+            if has_initial_states is not None and torch.any(
+                    has_initial_states):
+
+                # vectorized ssm_state zero init
+                batched_zero_init_func = torch.vmap(
+                    lambda idx: mamba_cache_params.ssm_state[idx].zero_())
+                batched_zero_init_func(
+                    mamba_cache_params.
+                    state_indices_tensor[~has_initial_states].unsqueeze(
+                        dim=-1), )
                 initial_states = mamba_cache_params.ssm_state[
                     mamba_cache_params.state_indices_tensor]
 
@@ -493,10 +500,17 @@ class MambaMixer2(CustomOp):
                 dt_limit=(0.0, float("inf")),
             )
 
-            # update ssm states
-            # - varlen state is a (batch, nheads, headdim, dstate) tensor
-            for i, idx in enumerate(mamba_cache_params.state_indices_tensor):
-                mamba_cache_params.ssm_state[idx].copy_(varlen_state[i])
+            # vectorized ssm state update using vmap
+            # the 1d state_indices_tensor needs to be unsqueezed to avoid vmap
+            # limitation which doesn't allow use of `item()`
+            # Note: the lambda capture can happen where ssm_state is initialized
+            # instead of here
+            batched_copy = torch.vmap(
+                lambda idx, source_state: mamba_cache_params.ssm_state[
+                    idx].copy_(source_state))
+            batched_copy(
+                mamba_cache_params.state_indices_tensor.unsqueeze(dim=-1),
+                varlen_state)
 
             # - reshape
             hidden_states = scan_output.view(seq_len, -1)
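
Note on the access pattern: both loops this patch removes amount to a batched
scatter into `ssm_state` along its first (cache-slot) dimension. The sketch
below shows the same two updates written with plain advanced indexing rather
than the patch's torch.vmap formulation; every shape and tensor name here is
a hypothetical stand-in, not vLLM's actual cache API.

    import torch

    # Hypothetical cache layout: one SSM state per cache slot.
    # ssm_state: (num_cache_slots, nheads, headdim, dstate)
    num_slots, nheads, headdim, dstate = 8, 4, 16, 32
    ssm_state = torch.randn(num_slots, nheads, headdim, dstate)

    # Made-up metadata for a prefill batch of 3 sequences.
    state_indices = torch.tensor([5, 2, 7])                 # cache slot per sequence
    has_initial_states = torch.tensor([True, False, True])  # per-sequence flag
    varlen_state = torch.randn(3, nheads, headdim, dstate)  # final state per sequence

    # Zero-init the slots of sequences without an initial state: one
    # index_put_ kernel instead of a Python loop of .zero_() calls.
    ssm_state[state_indices[~has_initial_states]] = 0.0

    # Batched write-back of the scan's final states: a single
    # index_copy_ along dim 0 replaces the per-row .copy_() loop.
    ssm_state.index_copy_(0, state_indices, varlen_state)

Either formulation makes the same point as the patch: issue one batched
device operation instead of a Python-level loop that dispatches a separate
small copy per sequence.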