[BugFix] fix 3 issues: (1) using metadata for causal-conv1d, (2) indexing overflow in v1 vLLM, and (3) init_states in v0 (#20838)

Signed-off-by: Tuan M. Hoang-Trong <tmhoangt@us.ibm.com>
Co-authored-by: Tuan M. Hoang-Trong <tmhoangt@us.ibm.com>
This commit is contained in:
Tuan, Hoang-Trong 2025-07-15 16:08:26 -04:00 committed by GitHub
parent ed10f3cea1
commit f29fd8a7f8
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 14 additions and 9 deletions

View File

@ -573,8 +573,8 @@ class MambaMixer2(MambaBase, CustomOp):
x = hidden_states_B_C_p.transpose(
0, 1) # this is the form that causal-conv see
if mamba2_metadata.cu_seqlen is None:
mamba2_metadata = update_metadata(
x, attn_metadata.query_start_loc, mamba2_metadata)
mamba2_metadata = update_metadata(x, query_start_loc_p,
mamba2_metadata)
hidden_states_B_C_p = causal_conv1d_fn(
x,
conv_weights,
@ -583,6 +583,7 @@ class MambaMixer2(MambaBase, CustomOp):
conv_states=conv_state,
has_initial_state=has_initial_states_p,
cache_indices=state_indices_tensor_p,
metadata=mamba2_metadata,
query_start_loc=query_start_loc_p).transpose(
0, 1)[:num_prefill_tokens]
@ -593,9 +594,14 @@ class MambaMixer2(MambaBase, CustomOp):
initial_states = None
if (has_initial_states_p is not None and prep_initial_states):
# making a copy of the states
initial_states = torch.where(
has_initial_states_p[:, None, None, None],
ssm_state[state_indices_tensor_p], 0)
if envs.VLLM_USE_V1:
initial_states = torch.where(
has_initial_states_p[:, None, None, None],
ssm_state[state_indices_tensor_p], 0)
else:
initial_states = torch.where(
has_initial_states_p[:num_prefills, None, None, None],
ssm_state[state_indices_tensor_p], 0)
scan_output, varlen_state = mamba_chunk_scan_combined(
hidden_states_p.view(1, num_prefill_tokens,

View File

@ -55,7 +55,6 @@ def _causal_conv1d_fwd_kernel( # continuous batching
IS_CONTINUOUS_BATCHING: tl.constexpr,
USE_PAD_SLOT: tl.constexpr,
NP2_STATELEN: tl.constexpr,
DECODE_SEQLEN: tl.constexpr,
BLOCK_M: tl.constexpr,
BLOCK_N: tl.constexpr,
):
@ -416,7 +415,7 @@ def causal_conv1d_fn(
activation = "silu"
args = None
out = torch.zeros_like(x)
out = torch.empty_like(x)
if metadata is not None:
cu_seqlen = metadata.cu_seqlen
nums_dict = metadata.nums_dict
@ -607,7 +606,6 @@ def causal_conv1d_fn(
IS_CONTINUOUS_BATCHING=cache_indices is not None,
USE_PAD_SLOT=pad_slot_id is not None,
NP2_STATELEN=np2_statelen,
DECODE_SEQLEN=1,
#launch_cooperative_grid=True
BLOCK_M=8,
BLOCK_N=256,
@ -665,7 +663,8 @@ def _causal_conv1d_update_kernel(
if IS_CONTINUOUS_BATCHING:
# mask = idx_seq < batch
conv_state_batch_coord = tl.load(conv_state_indices_ptr + idx_seq)
conv_state_batch_coord = tl.load(conv_state_indices_ptr + idx_seq).to(
tl.int64)
else:
conv_state_batch_coord = idx_seq
if USE_PAD_SLOT: # noqa