mirror of https://git.datalinker.icu/vllm-project/vllm.git, synced 2026-05-07 12:05:45 +08:00
Revert "[Model] Mamba2 Prefill Performance Tweaks: Fixing Flurry of U… (#14848)
This commit is contained in:
parent acaea3bb07
commit ccf02fcbae
@@ -466,17 +466,10 @@ class MambaMixer2(CustomOp):
         if has_prefill:
 
             initial_states = None
-            if has_initial_states is not None and torch.any(
-                    has_initial_states):
-                # vectorized ssm_state zero init
-                batched_zero_init_func = torch.vmap(
-                    lambda idx: mamba_cache_params.ssm_state[idx].zero_())
-                batched_zero_init_func(
-                    mamba_cache_params.
-                    state_indices_tensor[~has_initial_states].unsqueeze(
-                        dim=-1), )
+            if has_initial_states is not None and any(has_initial_states):
+                for idx in mamba_cache_params.state_indices_tensor[
+                        ~has_initial_states]:
+                    mamba_cache_params.ssm_state[idx].zero_()
                 initial_states = mamba_cache_params.ssm_state[
                     mamba_cache_params.state_indices_tensor]
 
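The restored loop zeroes one cache slot at a time for each request that carries no initial state, where the reverted commit had batched the zeroing through torch.vmap. Below is a minimal standalone sketch of the restored step; the shapes and values for ssm_state, state_indices_tensor, and has_initial_states are illustrative assumptions, not taken from vLLM:

import torch

# Assumed cache layout: 8 slots, each (nheads=4, headdim=16, dstate=32).
ssm_state = torch.randn(8, 4, 16, 32)
state_indices_tensor = torch.tensor([5, 2, 7])           # cache slot per request
has_initial_states = torch.tensor([True, False, False])  # per-request flag

# Restored behaviour: zero each slot that has no carried-over state.
for idx in state_indices_tensor[~has_initial_states]:
    ssm_state[idx].zero_()

# Gather the per-request initial states, as in the unchanged context lines.
initial_states = ssm_state[state_indices_tensor]

# One batched equivalent via plain advanced indexing; the reverted commit
# expressed this with torch.vmap over an in-place zero_ instead.
ssm_state[state_indices_tensor[~has_initial_states]] = 0.0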
@@ -500,17 +493,10 @@ class MambaMixer2(CustomOp):
                 dt_limit=(0.0, float("inf")),
             )
 
-            # vectorized ssm state update using vmap
-            # the 1d state_indices_tensor needs to be unsqueezed to avoid vmap
-            # limitation which doesn't allow use of `item()`
-            # Note: the lambda capture can happen where ssm_state is initialized
-            # instead of here
-            batched_copy = torch.vmap(
-                lambda idx, source_state: mamba_cache_params.ssm_state[
-                    idx].copy_(source_state))
-            batched_copy(
-                mamba_cache_params.state_indices_tensor.unsqueeze(dim=-1),
-                varlen_state)
+            # update ssm states
+            # - varlen state is a (batch, nheads, headdim, dstate) tensor
+            for i, idx in enumerate(mamba_cache_params.state_indices_tensor):
+                mamba_cache_params.ssm_state[idx].copy_(varlen_state[i])
 
             # - reshape
             hidden_states = scan_output.view(seq_len, -1)
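The second hunk likewise restores a per-request copy of each sequence's final varlen state back into its cache slot, replacing the vmap-based batched copy. A self-contained sketch under the same assumed shapes as above; index_copy_ is shown only as one vectorized alternative, not the code the commit reverts to:

import torch

# varlen_state is (batch, nheads, headdim, dstate), per the diff comment.
ssm_state = torch.zeros(8, 4, 16, 32)
state_indices_tensor = torch.tensor([5, 2, 7])
varlen_state = torch.randn(3, 4, 16, 32)

# Restored behaviour: scatter each final state into its cache slot.
for i, idx in enumerate(state_indices_tensor):
    ssm_state[idx].copy_(varlen_state[i])

# A batched equivalent along the slot dimension (dim 0).
ssm_state.index_copy_(0, state_indices_tensor, varlen_state)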