[Model] RE: Mamba2 Prefill Performance Tweaks: Fixing Flurry of Unnecessary Memory Copies (#14857)

Signed-off-by: Chih-Chieh-Yang <7364402+cyang49@users.noreply.github.com>
This commit is contained in:
Chih-Chieh Yang 2025-03-20 19:21:08 -07:00 committed by GitHub
parent 0032903a5b
commit 296f927f24
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -470,10 +470,11 @@ class MambaMixer2(CustomOp):
if has_prefill:
initial_states = None
if has_initial_states is not None and any(has_initial_states):
for idx in mamba_cache_params.state_indices_tensor[
~has_initial_states]:
mamba_cache_params.ssm_state[idx].zero_()
if has_initial_states is not None and torch.any(
has_initial_states):
zero_init_indices = mamba_cache_params.state_indices_tensor[
~has_initial_states]
mamba_cache_params.ssm_state[zero_init_indices] = 0
initial_states = mamba_cache_params.ssm_state[
mamba_cache_params.state_indices_tensor]
@ -499,8 +500,8 @@ class MambaMixer2(CustomOp):
# update ssm states
# - varlen state is a (batch, nheads, headdim, dstate) tensor
for i, idx in enumerate(mamba_cache_params.state_indices_tensor):
mamba_cache_params.ssm_state[idx].copy_(varlen_state[i])
mamba_cache_params.ssm_state[
mamba_cache_params.state_indices_tensor] = varlen_state
# - reshape
hidden_states = scan_output.view(seq_len, -1)