mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-03-16 15:07:16 +08:00
[BugFix] fix 3 issues: (1) using metadata for causal-conv1d, (2) indexing overflow in v1 vLLM, and (3) init_states in v0 (#20838)
Signed-off-by: Tuan M. Hoang-Trong <tmhoangt@us.ibm.com> Co-authored-by: Tuan M. Hoang-Trong <tmhoangt@us.ibm.com>
This commit is contained in:
parent
ed10f3cea1
commit
f29fd8a7f8
@ -573,8 +573,8 @@ class MambaMixer2(MambaBase, CustomOp):
|
||||
x = hidden_states_B_C_p.transpose(
|
||||
0, 1) # this is the form that causal-conv see
|
||||
if mamba2_metadata.cu_seqlen is None:
|
||||
mamba2_metadata = update_metadata(
|
||||
x, attn_metadata.query_start_loc, mamba2_metadata)
|
||||
mamba2_metadata = update_metadata(x, query_start_loc_p,
|
||||
mamba2_metadata)
|
||||
hidden_states_B_C_p = causal_conv1d_fn(
|
||||
x,
|
||||
conv_weights,
|
||||
@ -583,6 +583,7 @@ class MambaMixer2(MambaBase, CustomOp):
|
||||
conv_states=conv_state,
|
||||
has_initial_state=has_initial_states_p,
|
||||
cache_indices=state_indices_tensor_p,
|
||||
metadata=mamba2_metadata,
|
||||
query_start_loc=query_start_loc_p).transpose(
|
||||
0, 1)[:num_prefill_tokens]
|
||||
|
||||
@ -593,9 +594,14 @@ class MambaMixer2(MambaBase, CustomOp):
|
||||
initial_states = None
|
||||
if (has_initial_states_p is not None and prep_initial_states):
|
||||
# making a copy of the states
|
||||
initial_states = torch.where(
|
||||
has_initial_states_p[:, None, None, None],
|
||||
ssm_state[state_indices_tensor_p], 0)
|
||||
if envs.VLLM_USE_V1:
|
||||
initial_states = torch.where(
|
||||
has_initial_states_p[:, None, None, None],
|
||||
ssm_state[state_indices_tensor_p], 0)
|
||||
else:
|
||||
initial_states = torch.where(
|
||||
has_initial_states_p[:num_prefills, None, None, None],
|
||||
ssm_state[state_indices_tensor_p], 0)
|
||||
|
||||
scan_output, varlen_state = mamba_chunk_scan_combined(
|
||||
hidden_states_p.view(1, num_prefill_tokens,
|
||||
|
||||
@ -55,7 +55,6 @@ def _causal_conv1d_fwd_kernel( # continuous batching
|
||||
IS_CONTINUOUS_BATCHING: tl.constexpr,
|
||||
USE_PAD_SLOT: tl.constexpr,
|
||||
NP2_STATELEN: tl.constexpr,
|
||||
DECODE_SEQLEN: tl.constexpr,
|
||||
BLOCK_M: tl.constexpr,
|
||||
BLOCK_N: tl.constexpr,
|
||||
):
|
||||
@ -416,7 +415,7 @@ def causal_conv1d_fn(
|
||||
activation = "silu"
|
||||
|
||||
args = None
|
||||
out = torch.zeros_like(x)
|
||||
out = torch.empty_like(x)
|
||||
if metadata is not None:
|
||||
cu_seqlen = metadata.cu_seqlen
|
||||
nums_dict = metadata.nums_dict
|
||||
@ -607,7 +606,6 @@ def causal_conv1d_fn(
|
||||
IS_CONTINUOUS_BATCHING=cache_indices is not None,
|
||||
USE_PAD_SLOT=pad_slot_id is not None,
|
||||
NP2_STATELEN=np2_statelen,
|
||||
DECODE_SEQLEN=1,
|
||||
#launch_cooperative_grid=True
|
||||
BLOCK_M=8,
|
||||
BLOCK_N=256,
|
||||
@ -665,7 +663,8 @@ def _causal_conv1d_update_kernel(
|
||||
|
||||
if IS_CONTINUOUS_BATCHING:
|
||||
# mask = idx_seq < batch
|
||||
conv_state_batch_coord = tl.load(conv_state_indices_ptr + idx_seq)
|
||||
conv_state_batch_coord = tl.load(conv_state_indices_ptr + idx_seq).to(
|
||||
tl.int64)
|
||||
else:
|
||||
conv_state_batch_coord = idx_seq
|
||||
if USE_PAD_SLOT: # noqa
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user