mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-06-03 08:44:25 +08:00
[V1] [Hybrid] Some additional clean-up in Mamba2 prefix caching (#26222)
Signed-off-by: Thomas Parnell <tpa@zurich.ibm.com>
This commit is contained in:
parent
d3c84297c3
commit
778f554157
@ -595,21 +595,32 @@ class MambaMixer2(MambaBase, CustomOp):
|
|||||||
if prefix_caching_enabled:
|
if prefix_caching_enabled:
|
||||||
# If prefix caching is enabled, retrieve the relevant variables
|
# If prefix caching is enabled, retrieve the relevant variables
|
||||||
# for prefill and decode
|
# for prefill and decode
|
||||||
last_state_idx_d, last_state_idx_p = torch.split(
|
block_idx_last_computed_token_d, block_idx_last_computed_token_p = (
|
||||||
attn_metadata.last_state_idx, [num_decodes, num_prefills], dim=0
|
torch.split(
|
||||||
|
attn_metadata.block_idx_last_computed_token,
|
||||||
|
[num_decodes, num_prefills],
|
||||||
|
dim=0,
|
||||||
|
)
|
||||||
)
|
)
|
||||||
current_last_idx_d, current_last_idx_p = torch.split(
|
block_idx_last_scheduled_token_d, block_idx_last_scheduled_token_p = (
|
||||||
attn_metadata.current_last_idx, [num_decodes, num_prefills], dim=0
|
torch.split(
|
||||||
|
attn_metadata.block_idx_last_scheduled_token,
|
||||||
|
[num_decodes, num_prefills],
|
||||||
|
dim=0,
|
||||||
|
)
|
||||||
)
|
)
|
||||||
# Prefill-only variables:
|
# Prefill-only variables:
|
||||||
current_first_idx_p = attn_metadata.current_first_idx_p
|
block_idx_first_scheduled_token_p = (
|
||||||
context_lens_p = attn_metadata.context_lens_p
|
attn_metadata.block_idx_first_scheduled_token_p
|
||||||
last_computed_offset_p = attn_metadata.last_computed_offset_p
|
)
|
||||||
|
num_computed_tokens_p = attn_metadata.num_computed_tokens_p
|
||||||
else:
|
else:
|
||||||
last_state_idx_d, last_state_idx_p = None, None
|
block_idx_last_computed_token_d = None
|
||||||
current_last_idx_d, current_last_idx_p = None, None
|
block_idx_last_computed_token_p = None
|
||||||
current_first_idx_p = None
|
block_idx_last_scheduled_token_d = None
|
||||||
context_lens_p = None
|
block_idx_last_scheduled_token_p = None
|
||||||
|
block_idx_first_scheduled_token_p = None
|
||||||
|
num_computed_tokens_p = None
|
||||||
|
|
||||||
# Preallocate output tensor to avoid memcpy cost for merging prefill
|
# Preallocate output tensor to avoid memcpy cost for merging prefill
|
||||||
# and decode outputs
|
# and decode outputs
|
||||||
@ -637,7 +648,8 @@ class MambaMixer2(MambaBase, CustomOp):
|
|||||||
# to by "state_indices_tensor_p".
|
# to by "state_indices_tensor_p".
|
||||||
# In particular, it will always write the state at the
|
# In particular, it will always write the state at the
|
||||||
# sequence end.
|
# sequence end.
|
||||||
# In addition, "current_first_idx_p" and "current_last_idx_p"
|
# In addition, if "block_idx_first_scheduled_token_p" and
|
||||||
|
# "block_idx_last_scheduled_token_p"
|
||||||
# are provided (which are pointers into
|
# are provided (which are pointers into
|
||||||
# "state_indices_tensor_p"), it will write additional cache
|
# "state_indices_tensor_p"), it will write additional cache
|
||||||
# states aligned at "block_size_to_align".
|
# states aligned at "block_size_to_align".
|
||||||
@ -652,10 +664,10 @@ class MambaMixer2(MambaBase, CustomOp):
|
|||||||
conv_states=conv_state,
|
conv_states=conv_state,
|
||||||
has_initial_state=has_initial_states_p,
|
has_initial_state=has_initial_states_p,
|
||||||
cache_indices=state_indices_tensor_p,
|
cache_indices=state_indices_tensor_p,
|
||||||
current_first_idx=current_first_idx_p,
|
block_idx_first_scheduled_token=block_idx_first_scheduled_token_p,
|
||||||
current_last_idx=current_last_idx_p,
|
block_idx_last_scheduled_token=block_idx_last_scheduled_token_p,
|
||||||
initial_state_idx=last_state_idx_p,
|
initial_state_idx=block_idx_last_computed_token_p,
|
||||||
context_lens=context_lens_p,
|
num_computed_tokens=num_computed_tokens_p,
|
||||||
block_size_to_align=mamba_block_size,
|
block_size_to_align=mamba_block_size,
|
||||||
metadata=attn_metadata,
|
metadata=attn_metadata,
|
||||||
query_start_loc=query_start_loc_p,
|
query_start_loc=query_start_loc_p,
|
||||||
@ -669,7 +681,7 @@ class MambaMixer2(MambaBase, CustomOp):
|
|||||||
kernel_ssm_indices = state_indices_tensor_p
|
kernel_ssm_indices = state_indices_tensor_p
|
||||||
if prefix_caching_enabled:
|
if prefix_caching_enabled:
|
||||||
kernel_ssm_indices = state_indices_tensor_p.gather(
|
kernel_ssm_indices = state_indices_tensor_p.gather(
|
||||||
1, last_state_idx_p.unsqueeze(1)
|
1, block_idx_last_computed_token_p.unsqueeze(1)
|
||||||
).squeeze(1)
|
).squeeze(1)
|
||||||
initial_states = torch.where(
|
initial_states = torch.where(
|
||||||
has_initial_states_p[:, None, None, None],
|
has_initial_states_p[:, None, None, None],
|
||||||
@ -703,52 +715,76 @@ class MambaMixer2(MambaBase, CustomOp):
|
|||||||
)
|
)
|
||||||
|
|
||||||
if prefix_caching_enabled:
|
if prefix_caching_enabled:
|
||||||
# Save states for sequences with more than just the final state:
|
# The chunk_stride is the number of chunks per mamba block
|
||||||
n_blocks_to_fill = current_last_idx_p - current_first_idx_p
|
# e.g., if mamba_block_size = 512 and chunk_size = 256,
|
||||||
for seq_idx in (n_blocks_to_fill > 0).nonzero().squeeze(1):
|
# then chunk_stride = 2
|
||||||
|
chunk_stride = mamba_block_size // chunk_size
|
||||||
|
|
||||||
|
# Save state for sequences with more than just final state
|
||||||
|
for seq_idx in range(num_prefills):
|
||||||
|
# Block index for the first scheduled token
|
||||||
|
block_idx_first_scheduled_token = block_idx_first_scheduled_token_p[
|
||||||
|
seq_idx
|
||||||
|
]
|
||||||
|
|
||||||
|
# Block index for the last scheduled token
|
||||||
|
block_idx_last_scheduled_token = block_idx_last_scheduled_token_p[
|
||||||
|
seq_idx
|
||||||
|
]
|
||||||
|
|
||||||
|
# Number of blocks that need to be written
|
||||||
|
n_blocks_to_fill = (
|
||||||
|
block_idx_last_scheduled_token - block_idx_first_scheduled_token
|
||||||
|
)
|
||||||
|
|
||||||
|
# Skip sequences that don't have any blocks to fill
|
||||||
|
if n_blocks_to_fill == 0:
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Look up the state indices
|
||||||
cache_blocks_to_fill = state_indices_tensor_p[
|
cache_blocks_to_fill = state_indices_tensor_p[
|
||||||
seq_idx,
|
seq_idx,
|
||||||
current_first_idx_p[seq_idx] : current_first_idx_p[seq_idx]
|
block_idx_first_scheduled_token:block_idx_last_scheduled_token,
|
||||||
+ n_blocks_to_fill[seq_idx],
|
|
||||||
]
|
]
|
||||||
# chunks = [0 1 2 3 4 5 6 ...]
|
|
||||||
# First aligned chunk would typically be:
|
# First chunk index for this sequence
|
||||||
# mamba_block_size = 1024, chunk_size = 256
|
if seq_idx == 0:
|
||||||
# 1024 // 256 - 1 --> chunks[3]
|
first_chunk = 0
|
||||||
# But when last chunk wasn't block aligned:
|
else:
|
||||||
# - last_computed_offset_p[seq_idx] // chunk_size
|
first_chunk = 1 + last_chunk_indices_p[seq_idx - 1]
|
||||||
# e.g. 1000 // 256 -> 3 completed --> store chunk[0]
|
|
||||||
# e.g. 513 // 256 -> 2 completed --> store chunk[1] (skip 1)
|
# First chunk that is aligned on the mamba block boundary
|
||||||
# e.g. 256 // 256 -> 1 completed --> store chunk[2] (skip 2)
|
first_aligned_chunk = first_chunk + chunk_stride - 1
|
||||||
# e.g. 10 // 256 -> 0 completed --> store chunk[3] (skip 3)
|
|
||||||
chunk_stride = mamba_block_size // chunk_size
|
# Calculate the number of computed tokens that were not
|
||||||
first_aligned_chunk = (
|
# already cached
|
||||||
torch.concat(
|
num_unaligned_computed_tokens = (
|
||||||
[
|
num_computed_tokens_p[seq_idx] % mamba_block_size
|
||||||
torch.zeros(
|
|
||||||
1,
|
|
||||||
dtype=last_chunk_indices_p.dtype,
|
|
||||||
device=last_chunk_indices_p.device,
|
|
||||||
),
|
|
||||||
last_chunk_indices_p + 1,
|
|
||||||
]
|
|
||||||
)[seq_idx]
|
|
||||||
+ chunk_stride
|
|
||||||
- 1
|
|
||||||
- last_computed_offset_p[seq_idx] // chunk_size
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if num_unaligned_computed_tokens > 0:
|
||||||
|
# If the number of computed tokens is not block aligned,
|
||||||
|
# then we need to shift the index accordingly
|
||||||
|
first_aligned_chunk -= (
|
||||||
|
num_unaligned_computed_tokens // chunk_size
|
||||||
|
)
|
||||||
|
|
||||||
|
# Get states to write
|
||||||
from_where = varlen_states[
|
from_where = varlen_states[
|
||||||
first_aligned_chunk : first_aligned_chunk
|
first_aligned_chunk : first_aligned_chunk
|
||||||
+ n_blocks_to_fill[seq_idx] * chunk_stride : chunk_stride
|
+ n_blocks_to_fill * chunk_stride : chunk_stride
|
||||||
]
|
]
|
||||||
|
|
||||||
|
# Write the states
|
||||||
ssm_state[cache_blocks_to_fill] = from_where
|
ssm_state[cache_blocks_to_fill] = from_where
|
||||||
|
|
||||||
# For all seqs, store the last state (Note: might be partial):
|
# For all seqs, store the last state (note: might be partial):
|
||||||
ssm_state[
|
ssm_state[
|
||||||
state_indices_tensor_p.gather(
|
state_indices_tensor_p.gather(
|
||||||
1, current_last_idx_p.unsqueeze(1)
|
1, block_idx_last_scheduled_token_p.unsqueeze(1)
|
||||||
).squeeze(1)
|
).squeeze(1)
|
||||||
] = varlen_states[last_chunk_indices_p]
|
] = varlen_states[last_chunk_indices_p]
|
||||||
|
|
||||||
else:
|
else:
|
||||||
# update ssm states
|
# update ssm states
|
||||||
# - varlen state is a (num_prefills, nheads, headdim, dstate)
|
# - varlen state is a (num_prefills, nheads, headdim, dstate)
|
||||||
@ -759,14 +795,17 @@ class MambaMixer2(MambaBase, CustomOp):
|
|||||||
if has_decode:
|
if has_decode:
|
||||||
if prefix_caching_enabled:
|
if prefix_caching_enabled:
|
||||||
state_indices_tensor_d_input = state_indices_tensor_d.gather(
|
state_indices_tensor_d_input = state_indices_tensor_d.gather(
|
||||||
1, last_state_idx_d.unsqueeze(1)
|
1, block_idx_last_computed_token_d.unsqueeze(1)
|
||||||
).squeeze(1)
|
).squeeze(1)
|
||||||
state_indices_tensor_d_output = state_indices_tensor_d.gather(
|
state_indices_tensor_d_output = state_indices_tensor_d.gather(
|
||||||
1, current_last_idx_d.unsqueeze(1)
|
1, block_idx_last_scheduled_token_d.unsqueeze(1)
|
||||||
).squeeze(1)
|
).squeeze(1)
|
||||||
# Note:
|
# for decode:
|
||||||
# for decode always: current_first_idx_d == current_last_idx_d
|
# block_idx_first_scheduled_token_d ==
|
||||||
# at block boundaries: current_first_idx_d > last_state_idx_d
|
# block_idx_last_scheduled_token_d
|
||||||
|
# at block boundaries:
|
||||||
|
# block_idx_first_scheduled_token_d >
|
||||||
|
# block_idx_last_computed_token_d
|
||||||
else:
|
else:
|
||||||
# Without caching, read and write in-place to the same blocks:
|
# Without caching, read and write in-place to the same blocks:
|
||||||
state_indices_tensor_d_input = state_indices_tensor_d
|
state_indices_tensor_d_input = state_indices_tensor_d
|
||||||
@ -780,8 +819,8 @@ class MambaMixer2(MambaBase, CustomOp):
|
|||||||
self.conv1d.bias,
|
self.conv1d.bias,
|
||||||
self.activation,
|
self.activation,
|
||||||
conv_state_indices=state_indices_tensor_d,
|
conv_state_indices=state_indices_tensor_d,
|
||||||
current_last_idx=current_last_idx_d,
|
block_idx_last_scheduled_token=block_idx_last_scheduled_token_d,
|
||||||
initial_state_idx=last_state_idx_d,
|
initial_state_idx=block_idx_last_computed_token_d,
|
||||||
)
|
)
|
||||||
|
|
||||||
hidden_states_d, B_d, C_d = split_hidden_states_B_C_fn(hidden_states_B_C_d)
|
hidden_states_d, B_d, C_d = split_hidden_states_B_C_fn(hidden_states_B_C_d)
|
||||||
|
|||||||
@ -27,10 +27,10 @@ def _causal_conv1d_fwd_kernel( # continuous batching
|
|||||||
query_start_loc_ptr,
|
query_start_loc_ptr,
|
||||||
batch_ptr,
|
batch_ptr,
|
||||||
token_chunk_offset_ptr,
|
token_chunk_offset_ptr,
|
||||||
current_first_idx, # (batch,)
|
block_idx_first_scheduled_token, # (batch,)
|
||||||
current_last_idx, # (batch,)
|
block_idx_last_scheduled_token, # (batch,)
|
||||||
initial_state_idx, # (batch,)
|
initial_state_idx, # (batch,)
|
||||||
context_lens, # (batch,)
|
num_computed_tokens, # (batch,)
|
||||||
o_ptr, # (dim, seqlen) - actually pointing to x_ptr
|
o_ptr, # (dim, seqlen) - actually pointing to x_ptr
|
||||||
# Matrix dimensions
|
# Matrix dimensions
|
||||||
dim: tl.constexpr,
|
dim: tl.constexpr,
|
||||||
@ -94,9 +94,9 @@ def _causal_conv1d_fwd_kernel( # continuous batching
|
|||||||
# In particular, if prefix caching is enabled, the program writes additional cache states to "cache_indices_ptr"
|
# In particular, if prefix caching is enabled, the program writes additional cache states to "cache_indices_ptr"
|
||||||
|
|
||||||
# Get the length of the completed sequence so far and compute the offset.
|
# Get the length of the completed sequence so far and compute the offset.
|
||||||
current_first_index = tl.load(current_first_idx + idx_seq)
|
current_first_index = tl.load(block_idx_first_scheduled_token + idx_seq)
|
||||||
current_last_index = tl.load(current_last_idx + idx_seq)
|
current_last_index = tl.load(block_idx_last_scheduled_token + idx_seq)
|
||||||
sequence_completed_index = tl.load(context_lens + idx_seq)
|
sequence_completed_index = tl.load(num_computed_tokens + idx_seq)
|
||||||
|
|
||||||
# Compute the offset where the first stride_block_m-aligned first full block is
|
# Compute the offset where the first stride_block_m-aligned first full block is
|
||||||
# Value in "token-space"
|
# Value in "token-space"
|
||||||
@ -476,10 +476,10 @@ def causal_conv1d_fn(
|
|||||||
has_initial_state: Optional[torch.Tensor] = None,
|
has_initial_state: Optional[torch.Tensor] = None,
|
||||||
activation: Optional[str] = "silu",
|
activation: Optional[str] = "silu",
|
||||||
pad_slot_id: int = PAD_SLOT_ID,
|
pad_slot_id: int = PAD_SLOT_ID,
|
||||||
current_first_idx: Optional[torch.Tensor] = None,
|
block_idx_first_scheduled_token: Optional[torch.Tensor] = None,
|
||||||
current_last_idx: Optional[torch.Tensor] = None,
|
block_idx_last_scheduled_token: Optional[torch.Tensor] = None,
|
||||||
initial_state_idx: Optional[torch.Tensor] = None,
|
initial_state_idx: Optional[torch.Tensor] = None,
|
||||||
context_lens: Optional[torch.Tensor] = None,
|
num_computed_tokens: Optional[torch.Tensor] = None,
|
||||||
block_size_to_align=0,
|
block_size_to_align=0,
|
||||||
metadata=None,
|
metadata=None,
|
||||||
validate_data=False,
|
validate_data=False,
|
||||||
@ -523,13 +523,13 @@ def causal_conv1d_fn(
|
|||||||
for example: cache_indices = [pad_slot_id, 1, 20, pad_slot_id]
|
for example: cache_indices = [pad_slot_id, 1, 20, pad_slot_id]
|
||||||
in this case, the kernel will not process entries at
|
in this case, the kernel will not process entries at
|
||||||
indices 0 and 3
|
indices 0 and 3
|
||||||
current_first_idx: (batch,), dtype int32
|
block_idx_first_scheduled_token: (batch,), dtype int32
|
||||||
The pointer into cache_indices, where the first cache block to be filled is located.
|
The pointer into cache_indices, where the first cache block to be filled is located.
|
||||||
current_last_idx: (batch,), dtype int32
|
block_idx_last_scheduled_token: (batch,), dtype int32
|
||||||
The pointer into cache_indices, where the last cache block to be filled is located.
|
The pointer into cache_indices, where the last cache block to be filled is located.
|
||||||
initial_state_idx: (batch,), dtype int32
|
initial_state_idx: (batch,), dtype int32
|
||||||
The pointer into cache_indices, where the cache block containing the initial state is located.
|
The pointer into cache_indices, where the cache block containing the initial state is located.
|
||||||
context_lens: (batch,), dtype int32
|
num_computed_tokens: (batch,), dtype int32
|
||||||
The number of tokens already completed for each sequence
|
The number of tokens already completed for each sequence
|
||||||
block_size_to_align: int
|
block_size_to_align: int
|
||||||
The block size to align the cached states to
|
The block size to align the cached states to
|
||||||
@ -708,10 +708,10 @@ def causal_conv1d_fn(
|
|||||||
query_start_loc,
|
query_start_loc,
|
||||||
batch_ptr,
|
batch_ptr,
|
||||||
token_chunk_offset_ptr,
|
token_chunk_offset_ptr,
|
||||||
current_first_idx,
|
block_idx_first_scheduled_token,
|
||||||
current_last_idx,
|
block_idx_last_scheduled_token,
|
||||||
initial_state_idx,
|
initial_state_idx,
|
||||||
context_lens,
|
num_computed_tokens,
|
||||||
out,
|
out,
|
||||||
# Matrix dimensions
|
# Matrix dimensions
|
||||||
dim,
|
dim,
|
||||||
@ -735,7 +735,7 @@ def causal_conv1d_fn(
|
|||||||
HAS_BIAS=bias is not None,
|
HAS_BIAS=bias is not None,
|
||||||
KERNEL_WIDTH=width,
|
KERNEL_WIDTH=width,
|
||||||
SILU_ACTIVATION=activation in ["silu", "swish"],
|
SILU_ACTIVATION=activation in ["silu", "swish"],
|
||||||
IS_APC_ENABLED=current_last_idx is not None,
|
IS_APC_ENABLED=block_idx_last_scheduled_token is not None,
|
||||||
USE_PAD_SLOT=pad_slot_id is not None,
|
USE_PAD_SLOT=pad_slot_id is not None,
|
||||||
NP2_STATELEN=np2_statelen,
|
NP2_STATELEN=np2_statelen,
|
||||||
# launch_cooperative_grid=True
|
# launch_cooperative_grid=True
|
||||||
@ -756,7 +756,7 @@ def _causal_conv1d_update_kernel(
|
|||||||
conv_state_indices_ptr,
|
conv_state_indices_ptr,
|
||||||
num_accepted_tokens_ptr,
|
num_accepted_tokens_ptr,
|
||||||
query_start_loc_ptr, # (batch + 1)
|
query_start_loc_ptr, # (batch + 1)
|
||||||
current_last_idx, # (batch,)
|
block_idx_last_scheduled_token, # (batch,)
|
||||||
initial_state_idx, # (batch,)
|
initial_state_idx, # (batch,)
|
||||||
o_ptr, # (batch, dim, seqlen)
|
o_ptr, # (batch, dim, seqlen)
|
||||||
# Matrix dimensions
|
# Matrix dimensions
|
||||||
@ -802,7 +802,7 @@ def _causal_conv1d_update_kernel(
|
|||||||
if IS_APC_ENABLED:
|
if IS_APC_ENABLED:
|
||||||
# Get the state from the initial_state_idx
|
# Get the state from the initial_state_idx
|
||||||
conv_state_init = tl.load(initial_state_idx + idx_seq)
|
conv_state_init = tl.load(initial_state_idx + idx_seq)
|
||||||
current_last_index = tl.load(current_last_idx + idx_seq)
|
current_last_index = tl.load(block_idx_last_scheduled_token + idx_seq)
|
||||||
else:
|
else:
|
||||||
conv_state_init = 0
|
conv_state_init = 0
|
||||||
current_last_index = 0
|
current_last_index = 0
|
||||||
@ -1078,7 +1078,7 @@ def causal_conv1d_update(
|
|||||||
query_start_loc: Optional[torch.Tensor] = None,
|
query_start_loc: Optional[torch.Tensor] = None,
|
||||||
max_query_len: int = -1,
|
max_query_len: int = -1,
|
||||||
pad_slot_id: int = PAD_SLOT_ID,
|
pad_slot_id: int = PAD_SLOT_ID,
|
||||||
current_last_idx: Optional[torch.Tensor] = None,
|
block_idx_last_scheduled_token: Optional[torch.Tensor] = None,
|
||||||
initial_state_idx: Optional[torch.Tensor] = None,
|
initial_state_idx: Optional[torch.Tensor] = None,
|
||||||
validate_data=False,
|
validate_data=False,
|
||||||
):
|
):
|
||||||
@ -1097,7 +1097,7 @@ def causal_conv1d_update(
|
|||||||
If not None, the conv_state is a larger tensor along the batch dim,
|
If not None, the conv_state is a larger tensor along the batch dim,
|
||||||
and we are selecting the batch coords specified by conv_state_indices.
|
and we are selecting the batch coords specified by conv_state_indices.
|
||||||
Useful for a continuous batching scenario.
|
Useful for a continuous batching scenario.
|
||||||
current_last_idx: (batch,), dtype int32
|
block_idx_last_scheduled_token: (batch,), dtype int32
|
||||||
The pointer into conv_state_indices, where the last cache block to be filled is located.
|
The pointer into conv_state_indices, where the last cache block to be filled is located.
|
||||||
initial_state_idx: (batch,), dtype int32
|
initial_state_idx: (batch,), dtype int32
|
||||||
The pointer into conv_state_indices, where the cache block containing the initial state is located.
|
The pointer into conv_state_indices, where the cache block containing the initial state is located.
|
||||||
@ -1201,7 +1201,7 @@ def causal_conv1d_update(
|
|||||||
conv_state_indices,
|
conv_state_indices,
|
||||||
num_accepted_tokens,
|
num_accepted_tokens,
|
||||||
query_start_loc,
|
query_start_loc,
|
||||||
current_last_idx,
|
block_idx_last_scheduled_token,
|
||||||
initial_state_idx,
|
initial_state_idx,
|
||||||
out,
|
out,
|
||||||
# Matrix dimensions
|
# Matrix dimensions
|
||||||
@ -1230,7 +1230,7 @@ def causal_conv1d_update(
|
|||||||
KERNEL_WIDTH=width,
|
KERNEL_WIDTH=width,
|
||||||
SILU_ACTIVATION=activation in ["silu", "swish"],
|
SILU_ACTIVATION=activation in ["silu", "swish"],
|
||||||
IS_VARLEN=query_start_loc is not None,
|
IS_VARLEN=query_start_loc is not None,
|
||||||
IS_APC_ENABLED=current_last_idx is not None,
|
IS_APC_ENABLED=block_idx_last_scheduled_token is not None,
|
||||||
IS_SPEC_DECODING=num_accepted_tokens is not None,
|
IS_SPEC_DECODING=num_accepted_tokens is not None,
|
||||||
NP2_STATELEN=np2_statelen,
|
NP2_STATELEN=np2_statelen,
|
||||||
USE_PAD_SLOT=pad_slot_id is not None,
|
USE_PAD_SLOT=pad_slot_id is not None,
|
||||||
|
|||||||
@ -122,11 +122,10 @@ class Mamba2AttentionMetadata:
|
|||||||
last_chunk_indices_p: Optional[torch.Tensor]
|
last_chunk_indices_p: Optional[torch.Tensor]
|
||||||
|
|
||||||
state_indices_tensor: torch.Tensor # shape: [batch,]
|
state_indices_tensor: torch.Tensor # shape: [batch,]
|
||||||
current_last_idx: torch.Tensor
|
block_idx_last_scheduled_token: torch.Tensor # shape: [batch,]
|
||||||
current_first_idx_p: torch.Tensor
|
block_idx_first_scheduled_token_p: torch.Tensor # shape: [batch,]
|
||||||
last_state_idx: torch.Tensor
|
block_idx_last_computed_token: torch.Tensor # shape: [batch,]
|
||||||
context_lens_p: torch.Tensor
|
num_computed_tokens_p: torch.Tensor # shape: [batch,]
|
||||||
last_computed_offset_p: torch.Tensor
|
|
||||||
|
|
||||||
# The following attributes are for triton implementation of causal_conv1d
|
# The following attributes are for triton implementation of causal_conv1d
|
||||||
nums_dict: Optional[dict] = None
|
nums_dict: Optional[dict] = None
|
||||||
@ -160,12 +159,12 @@ class Mamba2AttentionMetadataBuilder(
|
|||||||
dtype=torch.int32,
|
dtype=torch.int32,
|
||||||
device=device,
|
device=device,
|
||||||
)
|
)
|
||||||
self.current_last_idx = torch.empty(
|
self.block_idx_last_scheduled_token = torch.empty(
|
||||||
(self.decode_cudagraph_max_bs,),
|
(self.decode_cudagraph_max_bs,),
|
||||||
dtype=torch.int32,
|
dtype=torch.int32,
|
||||||
device=device,
|
device=device,
|
||||||
)
|
)
|
||||||
self.last_state_idx = torch.empty(
|
self.block_idx_last_computed_token = torch.empty(
|
||||||
(self.decode_cudagraph_max_bs,),
|
(self.decode_cudagraph_max_bs,),
|
||||||
dtype=torch.int32,
|
dtype=torch.int32,
|
||||||
device=device,
|
device=device,
|
||||||
@ -192,43 +191,38 @@ class Mamba2AttentionMetadataBuilder(
|
|||||||
# for causal_conv1d
|
# for causal_conv1d
|
||||||
nums_dict, batch_ptr, token_chunk_offset_ptr = None, None, None
|
nums_dict, batch_ptr, token_chunk_offset_ptr = None, None, None
|
||||||
|
|
||||||
context_lens, context_lens_p = None, None
|
num_computed_tokens, num_computed_tokens_p = None, None
|
||||||
current_first_idx, current_first_idx_p = None, None
|
block_idx_first_scheduled_token = None
|
||||||
last_computed_offset, last_computed_offset_p = None, None
|
block_idx_first_scheduled_token_p = None
|
||||||
|
|
||||||
if self.vllm_config.cache_config.enable_prefix_caching:
|
if self.vllm_config.cache_config.enable_prefix_caching:
|
||||||
# Return a tensor of shape (#requests, #max blocks)
|
# Return a tensor of shape (#requests, #max blocks)
|
||||||
state_indices_tensor = common_attn_metadata.block_table_tensor
|
state_indices_tensor = common_attn_metadata.block_table_tensor
|
||||||
|
|
||||||
# Additional cache-related variables:
|
# Additional cache-related variables:
|
||||||
mamba_block_size = self.kv_cache_spec.block_size
|
mamba_block_size = self.kv_cache_spec.block_size
|
||||||
seq_lens_pending = (
|
num_computed_tokens = common_attn_metadata.num_computed_tokens_cpu.to(
|
||||||
torch.roll(common_attn_metadata.query_start_loc, -1, -1)
|
self.device
|
||||||
- common_attn_metadata.query_start_loc
|
)
|
||||||
)[:-1]
|
# Block index of the last computed token
|
||||||
context_lens = common_attn_metadata.seq_lens - seq_lens_pending
|
block_idx_last_computed_token = (
|
||||||
last_computed_offset = context_lens % mamba_block_size
|
cdiv(num_computed_tokens, mamba_block_size) - 1
|
||||||
# Indices: last_computed <= current_first <= current_last
|
)
|
||||||
# Cases:
|
# which is <= block index for the first scheduled token
|
||||||
# last_computed == current_first if last state was partially
|
block_idx_first_scheduled_token = (
|
||||||
# computed and needs to be updated
|
cdiv(num_computed_tokens + 1, mamba_block_size) - 1
|
||||||
# current_first == current_last if no block crossing occurs, and
|
)
|
||||||
# only one state will be stored
|
# which is <= block index of the last scheduled token
|
||||||
# 0th based indexing leads to "-1" -> e.g. 16 computed -> state[15]:
|
block_idx_last_scheduled_token = (
|
||||||
current_last_idx = (
|
cdiv(common_attn_metadata.seq_lens, mamba_block_size) - 1
|
||||||
cdiv(context_lens + seq_lens_pending, mamba_block_size) - 1
|
|
||||||
)
|
)
|
||||||
current_first_idx = cdiv(context_lens + 1, mamba_block_size) - 1
|
|
||||||
last_state_idx = cdiv(context_lens, mamba_block_size) - 1
|
|
||||||
# -1 in case it's non-computed and causes later issues with indexing
|
# -1 in case it's non-computed and causes later issues with indexing
|
||||||
last_state_idx = last_state_idx.clamp(min=0)
|
block_idx_last_computed_token = block_idx_last_computed_token.clamp(min=0)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
# Always return just a single block per each request:
|
# Always return just a single block per each request:
|
||||||
state_indices_tensor = common_attn_metadata.block_table_tensor[:, 0]
|
state_indices_tensor = common_attn_metadata.block_table_tensor[:, 0]
|
||||||
# Additional cache-related variables:
|
# Additional cache-related variables:
|
||||||
current_last_idx = None
|
block_idx_last_scheduled_token = None
|
||||||
last_state_idx = None
|
block_idx_last_computed_token = None
|
||||||
|
|
||||||
num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = (
|
num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = (
|
||||||
split_decodes_and_prefills(
|
split_decodes_and_prefills(
|
||||||
@ -256,18 +250,15 @@ class Mamba2AttentionMetadataBuilder(
|
|||||||
)
|
)
|
||||||
|
|
||||||
if self.vllm_config.cache_config.enable_prefix_caching:
|
if self.vllm_config.cache_config.enable_prefix_caching:
|
||||||
assert context_lens is not None
|
assert num_computed_tokens is not None
|
||||||
context_lens_p = context_lens[num_reqs - num_prefills : num_reqs]
|
num_computed_tokens_p = num_computed_tokens[
|
||||||
assert last_computed_offset is not None
|
|
||||||
last_computed_offset_p = last_computed_offset[
|
|
||||||
num_reqs - num_prefills : num_reqs
|
num_reqs - num_prefills : num_reqs
|
||||||
]
|
]
|
||||||
assert current_first_idx is not None
|
assert block_idx_first_scheduled_token is not None
|
||||||
current_first_idx_p = current_first_idx[
|
block_idx_first_scheduled_token_p = block_idx_first_scheduled_token[
|
||||||
num_reqs - num_prefills : num_reqs
|
num_reqs - num_prefills : num_reqs
|
||||||
]
|
]
|
||||||
|
num_computed_tokens_p_cpu = common_attn_metadata.num_computed_tokens_cpu[
|
||||||
num_computed_tokens_p = common_attn_metadata.num_computed_tokens_cpu[
|
|
||||||
num_reqs - num_prefills : num_reqs
|
num_reqs - num_prefills : num_reqs
|
||||||
]
|
]
|
||||||
query_start_loc_p_cpu = (
|
query_start_loc_p_cpu = (
|
||||||
@ -290,7 +281,7 @@ class Mamba2AttentionMetadataBuilder(
|
|||||||
last_chunk_indices = []
|
last_chunk_indices = []
|
||||||
seqlen_pos = 0
|
seqlen_pos = 0
|
||||||
for req_idx in range(num_prefills):
|
for req_idx in range(num_prefills):
|
||||||
this_num_computed = num_computed_tokens_p[req_idx].item()
|
this_num_computed = num_computed_tokens_p_cpu[req_idx].item()
|
||||||
this_new_tokens = (
|
this_new_tokens = (
|
||||||
query_start_loc_p_cpu[req_idx + 1].item()
|
query_start_loc_p_cpu[req_idx + 1].item()
|
||||||
- query_start_loc_p_cpu[req_idx].item()
|
- query_start_loc_p_cpu[req_idx].item()
|
||||||
@ -338,7 +329,10 @@ class Mamba2AttentionMetadataBuilder(
|
|||||||
compute_causal_conv1d_metadata(query_start_loc_p)
|
compute_causal_conv1d_metadata(query_start_loc_p)
|
||||||
)
|
)
|
||||||
|
|
||||||
elif num_decodes <= self.decode_cudagraph_max_bs:
|
elif (
|
||||||
|
num_decodes <= self.decode_cudagraph_max_bs
|
||||||
|
and self.compilation_config.full_cuda_graph
|
||||||
|
):
|
||||||
# Pad state tensor for CUDA graph
|
# Pad state tensor for CUDA graph
|
||||||
num_input_tokens = self.vllm_config.pad_for_cudagraph(num_decodes)
|
num_input_tokens = self.vllm_config.pad_for_cudagraph(num_decodes)
|
||||||
self.state_indices_tensor[:num_decodes].copy_(
|
self.state_indices_tensor[:num_decodes].copy_(
|
||||||
@ -348,17 +342,21 @@ class Mamba2AttentionMetadataBuilder(
|
|||||||
state_indices_tensor[num_decodes:] = PAD_SLOT_ID
|
state_indices_tensor[num_decodes:] = PAD_SLOT_ID
|
||||||
|
|
||||||
if self.vllm_config.cache_config.enable_prefix_caching:
|
if self.vllm_config.cache_config.enable_prefix_caching:
|
||||||
self.current_last_idx[:num_decodes].copy_(
|
self.block_idx_last_scheduled_token[:num_decodes].copy_(
|
||||||
current_last_idx, non_blocking=True
|
block_idx_last_scheduled_token, non_blocking=True
|
||||||
)
|
)
|
||||||
current_last_idx = self.current_last_idx[:num_input_tokens]
|
block_idx_last_scheduled_token = self.block_idx_last_scheduled_token[
|
||||||
current_last_idx[num_decodes:] = 0
|
:num_input_tokens
|
||||||
|
]
|
||||||
|
block_idx_last_scheduled_token[num_decodes:] = 0
|
||||||
|
|
||||||
self.last_state_idx[:num_decodes].copy_(
|
self.block_idx_last_computed_token[:num_decodes].copy_(
|
||||||
last_state_idx, non_blocking=True
|
block_idx_last_computed_token, non_blocking=True
|
||||||
)
|
)
|
||||||
last_state_idx = self.last_state_idx[:num_input_tokens]
|
block_idx_last_computed_token = self.block_idx_last_computed_token[
|
||||||
last_state_idx[num_decodes:] = 0
|
:num_input_tokens
|
||||||
|
]
|
||||||
|
block_idx_last_computed_token[num_decodes:] = 0
|
||||||
|
|
||||||
attn_metadata = Mamba2AttentionMetadata(
|
attn_metadata = Mamba2AttentionMetadata(
|
||||||
num_prefills=num_prefills,
|
num_prefills=num_prefills,
|
||||||
@ -377,10 +375,9 @@ class Mamba2AttentionMetadataBuilder(
|
|||||||
nums_dict=nums_dict,
|
nums_dict=nums_dict,
|
||||||
batch_ptr=batch_ptr,
|
batch_ptr=batch_ptr,
|
||||||
token_chunk_offset_ptr=token_chunk_offset_ptr,
|
token_chunk_offset_ptr=token_chunk_offset_ptr,
|
||||||
current_last_idx=current_last_idx,
|
block_idx_last_scheduled_token=block_idx_last_scheduled_token,
|
||||||
current_first_idx_p=current_first_idx_p,
|
block_idx_first_scheduled_token_p=block_idx_first_scheduled_token_p,
|
||||||
last_state_idx=last_state_idx,
|
block_idx_last_computed_token=block_idx_last_computed_token,
|
||||||
context_lens_p=context_lens_p,
|
num_computed_tokens_p=num_computed_tokens_p,
|
||||||
last_computed_offset_p=last_computed_offset_p,
|
|
||||||
)
|
)
|
||||||
return attn_metadata
|
return attn_metadata
|
||||||
|
|||||||
@ -584,8 +584,7 @@ class MambaManager(SingleTypeKVCacheManager):
|
|||||||
# hit_length = len(hit_blocks_other_attn[0])
|
# hit_length = len(hit_blocks_other_attn[0])
|
||||||
# * self.other_block_size
|
# * self.other_block_size
|
||||||
# so we insert dummy blocks at the beginning:
|
# so we insert dummy blocks at the beginning:
|
||||||
if i > 0:
|
computed.extend([block_pool.null_block] * i)
|
||||||
computed.extend([block_pool.null_block] * i)
|
|
||||||
computed.append(cached)
|
computed.append(cached)
|
||||||
break # we just need the last match - early stopping
|
break # we just need the last match - early stopping
|
||||||
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user