mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-09 18:46:25 +08:00
[Model] RE: Mamba2 Prefill Performance Tweaks: Fixing Flurry of Unnecessary Memory Copies (#14857)
Signed-off-by: Chih-Chieh-Yang <7364402+cyang49@users.noreply.github.com>
This commit is contained in:
parent
0032903a5b
commit
296f927f24
@ -470,10 +470,11 @@ class MambaMixer2(CustomOp):
|
||||
if has_prefill:
|
||||
|
||||
initial_states = None
|
||||
if has_initial_states is not None and any(has_initial_states):
|
||||
for idx in mamba_cache_params.state_indices_tensor[
|
||||
~has_initial_states]:
|
||||
mamba_cache_params.ssm_state[idx].zero_()
|
||||
if has_initial_states is not None and torch.any(
|
||||
has_initial_states):
|
||||
zero_init_indices = mamba_cache_params.state_indices_tensor[
|
||||
~has_initial_states]
|
||||
mamba_cache_params.ssm_state[zero_init_indices] = 0
|
||||
initial_states = mamba_cache_params.ssm_state[
|
||||
mamba_cache_params.state_indices_tensor]
|
||||
|
||||
@ -499,8 +500,8 @@ class MambaMixer2(CustomOp):
|
||||
|
||||
# update ssm states
|
||||
# - varlen state is a (batch, nheads, headdim, dstate) tensor
|
||||
for i, idx in enumerate(mamba_cache_params.state_indices_tensor):
|
||||
mamba_cache_params.ssm_state[idx].copy_(varlen_state[i])
|
||||
mamba_cache_params.ssm_state[
|
||||
mamba_cache_params.state_indices_tensor] = varlen_state
|
||||
|
||||
# - reshape
|
||||
hidden_states = scan_output.view(seq_len, -1)
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user