From 296f927f2493908984707354e3cc5d7b2e41650b Mon Sep 17 00:00:00 2001
From: Chih-Chieh Yang <7364402+cyang49@users.noreply.github.com>
Date: Thu, 20 Mar 2025 19:21:08 -0700
Subject: [PATCH] [Model] RE: Mamba2 Prefill Performance Tweaks: Fixing Flurry of Unnecessary Memory Copies (#14857)

Signed-off-by: Chih-Chieh-Yang <7364402+cyang49@users.noreply.github.com>
---
 vllm/model_executor/layers/mamba/mamba_mixer2.py | 13 +++++++------
 1 file changed, 7 insertions(+), 6 deletions(-)

diff --git a/vllm/model_executor/layers/mamba/mamba_mixer2.py b/vllm/model_executor/layers/mamba/mamba_mixer2.py
index fec6d6112d665..d7a45bc51239a 100644
--- a/vllm/model_executor/layers/mamba/mamba_mixer2.py
+++ b/vllm/model_executor/layers/mamba/mamba_mixer2.py
@@ -470,10 +470,11 @@ class MambaMixer2(CustomOp):
         if has_prefill:
 
             initial_states = None
-            if has_initial_states is not None and any(has_initial_states):
-                for idx in mamba_cache_params.state_indices_tensor[
-                        ~has_initial_states]:
-                    mamba_cache_params.ssm_state[idx].zero_()
+            if has_initial_states is not None and torch.any(
+                    has_initial_states):
+                zero_init_indices = mamba_cache_params.state_indices_tensor[
+                    ~has_initial_states]
+                mamba_cache_params.ssm_state[zero_init_indices] = 0
                 initial_states = mamba_cache_params.ssm_state[
                     mamba_cache_params.state_indices_tensor]
 
@@ -499,8 +500,8 @@ class MambaMixer2(CustomOp):
 
             # update ssm states
             # - varlen state is a (batch, nheads, headdim, dstate) tensor
-            for i, idx in enumerate(mamba_cache_params.state_indices_tensor):
-                mamba_cache_params.ssm_state[idx].copy_(varlen_state[i])
+            mamba_cache_params.ssm_state[
+                mamba_cache_params.state_indices_tensor] = varlen_state
 
             # - reshape
             hidden_states = scan_output.view(seq_len, -1)