mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-04-23 13:27:07 +08:00
[BugFix] fix 3 issues: (1) using metadata for causal-conv1d, (2) indexing overflow in v1 vLLM, and (3) init_states in v0 (#20838)
Signed-off-by: Tuan M. Hoang-Trong <tmhoangt@us.ibm.com> Co-authored-by: Tuan M. Hoang-Trong <tmhoangt@us.ibm.com>
This commit is contained in:
parent
ed10f3cea1
commit
f29fd8a7f8
@ -573,8 +573,8 @@ class MambaMixer2(MambaBase, CustomOp):
|
|||||||
x = hidden_states_B_C_p.transpose(
|
x = hidden_states_B_C_p.transpose(
|
||||||
0, 1) # this is the form that causal-conv see
|
0, 1) # this is the form that causal-conv see
|
||||||
if mamba2_metadata.cu_seqlen is None:
|
if mamba2_metadata.cu_seqlen is None:
|
||||||
mamba2_metadata = update_metadata(
|
mamba2_metadata = update_metadata(x, query_start_loc_p,
|
||||||
x, attn_metadata.query_start_loc, mamba2_metadata)
|
mamba2_metadata)
|
||||||
hidden_states_B_C_p = causal_conv1d_fn(
|
hidden_states_B_C_p = causal_conv1d_fn(
|
||||||
x,
|
x,
|
||||||
conv_weights,
|
conv_weights,
|
||||||
@ -583,6 +583,7 @@ class MambaMixer2(MambaBase, CustomOp):
|
|||||||
conv_states=conv_state,
|
conv_states=conv_state,
|
||||||
has_initial_state=has_initial_states_p,
|
has_initial_state=has_initial_states_p,
|
||||||
cache_indices=state_indices_tensor_p,
|
cache_indices=state_indices_tensor_p,
|
||||||
|
metadata=mamba2_metadata,
|
||||||
query_start_loc=query_start_loc_p).transpose(
|
query_start_loc=query_start_loc_p).transpose(
|
||||||
0, 1)[:num_prefill_tokens]
|
0, 1)[:num_prefill_tokens]
|
||||||
|
|
||||||
@ -593,9 +594,14 @@ class MambaMixer2(MambaBase, CustomOp):
|
|||||||
initial_states = None
|
initial_states = None
|
||||||
if (has_initial_states_p is not None and prep_initial_states):
|
if (has_initial_states_p is not None and prep_initial_states):
|
||||||
# making a copy of the states
|
# making a copy of the states
|
||||||
initial_states = torch.where(
|
if envs.VLLM_USE_V1:
|
||||||
has_initial_states_p[:, None, None, None],
|
initial_states = torch.where(
|
||||||
ssm_state[state_indices_tensor_p], 0)
|
has_initial_states_p[:, None, None, None],
|
||||||
|
ssm_state[state_indices_tensor_p], 0)
|
||||||
|
else:
|
||||||
|
initial_states = torch.where(
|
||||||
|
has_initial_states_p[:num_prefills, None, None, None],
|
||||||
|
ssm_state[state_indices_tensor_p], 0)
|
||||||
|
|
||||||
scan_output, varlen_state = mamba_chunk_scan_combined(
|
scan_output, varlen_state = mamba_chunk_scan_combined(
|
||||||
hidden_states_p.view(1, num_prefill_tokens,
|
hidden_states_p.view(1, num_prefill_tokens,
|
||||||
|
|||||||
@ -55,7 +55,6 @@ def _causal_conv1d_fwd_kernel( # continuous batching
|
|||||||
IS_CONTINUOUS_BATCHING: tl.constexpr,
|
IS_CONTINUOUS_BATCHING: tl.constexpr,
|
||||||
USE_PAD_SLOT: tl.constexpr,
|
USE_PAD_SLOT: tl.constexpr,
|
||||||
NP2_STATELEN: tl.constexpr,
|
NP2_STATELEN: tl.constexpr,
|
||||||
DECODE_SEQLEN: tl.constexpr,
|
|
||||||
BLOCK_M: tl.constexpr,
|
BLOCK_M: tl.constexpr,
|
||||||
BLOCK_N: tl.constexpr,
|
BLOCK_N: tl.constexpr,
|
||||||
):
|
):
|
||||||
@ -416,7 +415,7 @@ def causal_conv1d_fn(
|
|||||||
activation = "silu"
|
activation = "silu"
|
||||||
|
|
||||||
args = None
|
args = None
|
||||||
out = torch.zeros_like(x)
|
out = torch.empty_like(x)
|
||||||
if metadata is not None:
|
if metadata is not None:
|
||||||
cu_seqlen = metadata.cu_seqlen
|
cu_seqlen = metadata.cu_seqlen
|
||||||
nums_dict = metadata.nums_dict
|
nums_dict = metadata.nums_dict
|
||||||
@ -607,7 +606,6 @@ def causal_conv1d_fn(
|
|||||||
IS_CONTINUOUS_BATCHING=cache_indices is not None,
|
IS_CONTINUOUS_BATCHING=cache_indices is not None,
|
||||||
USE_PAD_SLOT=pad_slot_id is not None,
|
USE_PAD_SLOT=pad_slot_id is not None,
|
||||||
NP2_STATELEN=np2_statelen,
|
NP2_STATELEN=np2_statelen,
|
||||||
DECODE_SEQLEN=1,
|
|
||||||
#launch_cooperative_grid=True
|
#launch_cooperative_grid=True
|
||||||
BLOCK_M=8,
|
BLOCK_M=8,
|
||||||
BLOCK_N=256,
|
BLOCK_N=256,
|
||||||
@ -665,7 +663,8 @@ def _causal_conv1d_update_kernel(
|
|||||||
|
|
||||||
if IS_CONTINUOUS_BATCHING:
|
if IS_CONTINUOUS_BATCHING:
|
||||||
# mask = idx_seq < batch
|
# mask = idx_seq < batch
|
||||||
conv_state_batch_coord = tl.load(conv_state_indices_ptr + idx_seq)
|
conv_state_batch_coord = tl.load(conv_state_indices_ptr + idx_seq).to(
|
||||||
|
tl.int64)
|
||||||
else:
|
else:
|
||||||
conv_state_batch_coord = idx_seq
|
conv_state_batch_coord = idx_seq
|
||||||
if USE_PAD_SLOT: # noqa
|
if USE_PAD_SLOT: # noqa
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user