Revert "[Model] Mamba2 Prefill Performance Tweaks: Fixing Flurry of Unnecessary Memory Copies" (#14848)

This commit is contained in:
Tyler Michael Smith 2025-03-14 23:45:42 -04:00 committed by GitHub
parent acaea3bb07
commit ccf02fcbae
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@@ -466,17 +466,10 @@ class MambaMixer2(CustomOp):
         if has_prefill:
             initial_states = None
-            if has_initial_states is not None and torch.any(
-                    has_initial_states):
-                # vectorized ssm_state zero init
-                batched_zero_init_func = torch.vmap(
-                    lambda idx: mamba_cache_params.ssm_state[idx].zero_())
-                batched_zero_init_func(
-                    mamba_cache_params.
-                    state_indices_tensor[~has_initial_states].unsqueeze(
-                        dim=-1), )
+            if has_initial_states is not None and any(has_initial_states):
+                for idx in mamba_cache_params.state_indices_tensor[
+                        ~has_initial_states]:
+                    mamba_cache_params.ssm_state[idx].zero_()
             initial_states = mamba_cache_params.ssm_state[
                 mamba_cache_params.state_indices_tensor]
@@ -500,17 +493,10 @@ class MambaMixer2(CustomOp):
                 dt_limit=(0.0, float("inf")),
             )
-                # vectorized ssm state update using vmap
-                # the 1d state_indices_tensor needs to be unsqueezed to avoid vmap
-                # limitation which doesn't allow use of `item()`
-                # Note: the lambda capture can happen where ssm_state is initialized
-                # instead of here
-                batched_copy = torch.vmap(
-                    lambda idx, source_state: mamba_cache_params.ssm_state[
-                        idx].copy_(source_state))
-                batched_copy(
-                    mamba_cache_params.state_indices_tensor.unsqueeze(dim=-1),
-                    varlen_state)
+                # update ssm states
+                # - varlen state is a (batch, nheads, headdim, dstate) tensor
+                for i, idx in enumerate(mamba_cache_params.state_indices_tensor):
+                    mamba_cache_params.ssm_state[idx].copy_(varlen_state[i])
                 # - reshape
                 hidden_states = scan_output.view(seq_len, -1)