[Bugfix] Fix mamba2 prefill chunking (#23279)

Signed-off-by: Tomer Asida <57313761+tomeras91@users.noreply.github.com>
Signed-off-by: tomeras91 <57313761+tomeras91@users.noreply.github.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
tomeras91 2025-09-08 14:42:41 +03:00 committed by GitHub
parent 5e537f45b4
commit e041314184
5 changed files with 349 additions and 35 deletions


@@ -115,21 +115,27 @@ def generate_continuous_batched_examples(example_lens_by_batch,
n_heads,
d_head,
itype,
-                                          device='cuda'):
+                                          device='cuda',
+                                          return_naive_ref=True):
    # this function generates random examples of a certain length
    # and then cuts them according to "example_lens_by_batch" and feeds
-   # them in continuous batches to the kernels
+   # them in continuous batches to the kernels.
+   # If return_naive_ref=True, the naive torch implementation
+   # ssd_minimal_discrete will be used to compute and return the
+   # reference output.
# generate the full-length example
A, dt, X, B, C = generate_random_inputs(num_examples, full_length, n_heads,
d_head, itype)
-    Y_min, final_state_min = ssd_minimal_discrete(X * dt.unsqueeze(-1),
-                                                  A * dt,
-                                                  B,
-                                                  C,
-                                                  block_len=full_length // 4)
+    if return_naive_ref:
+        Y_min, final_state_min = ssd_minimal_discrete(X * dt.unsqueeze(-1),
+                                                      A * dt,
+                                                      B,
+                                                      C,
+                                                      block_len=full_length // 4)
# internal function that outputs a cont batch of examples
# given a tuple of lengths for each example in the batch
@@ -179,7 +185,8 @@ def generate_continuous_batched_examples(example_lens_by_batch,
IND_S = [x % full_length for x in IND_E]
IND_E = [end_boundary(x + y) for x, y in zip(IND_S, spec)]
-        yield ([Y_min[s, IND_S[s]:IND_E[s]] for s in range(num_examples)],
+        yield ([Y_min[s, IND_S[s]:IND_E[s]]
+                for s in range(num_examples)] if return_naive_ref else None,
cu_seqlens, seq_idx.unsqueeze(0), (A, dt2, X2, B2, C2))
@@ -324,3 +331,213 @@ def test_mamba_chunk_scan_cont_batch(d_head, n_heads, seq_len_chunk_size_cases,
if clear:
states[i].fill_(0.)
exhausted[i] = False
@pytest.mark.parametrize("chunk_size", [8, 256])
@pytest.mark.parametrize("seqlens", [
(16, 2, 8, 13),
(270, 88, 212, 203),
(16, 20),
])
def test_mamba_chunk_scan_cont_batch_prefill_chunking(chunk_size, seqlens):
# This test verifies the correctness of the chunked prefill implementation
# in the mamba2 ssd kernels, by comparing concatenation (in the sequence
# dimension) of chunked results with the full sequence result.
# It is different from test_mamba_chunk_scan_cont_batch by:
    # 1. Not using the naive torch implementation (ssd_minimal_discrete) to get
# reference outputs. Instead, it compares chunked kernel outputs to full
# sequence kernel outputs. This is the most straightforward way to
# assert chunked prefill correctness.
# 2. It focuses on cases where sequences change in the middle of mamba
# chunks, and not necessarily on chunk boundaries.
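    #    For example, with chunk_size=8 and seqlens=(16, 2, 8, 13), the
    #    sequence boundaries fall at tokens 16, 18, 26 and 39; all but the
    #    first land mid-chunk (physical chunk boundaries are multiples of 8).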
max_seqlen = max(seqlens)
# This test can have larger error for longer sequences
if max_seqlen > 256:
atol, rtol = 1e-2, 5e-3
else:
atol, rtol = 5e-3, 5e-3
num_sequences = len(seqlens)
n_heads = 16
d_head = 64
itype = torch.float32
# hold state during the cutting process so we know if an
# example has been exhausted and needs to cycle
last_taken: dict = {} # map: eg -> pointer to last taken sample
exhausted: dict = {} # map: eg -> boolean indicating example is exhausted
_, cu_seqlens, seq_idx, (A, dt, X, B, C) = next(
generate_continuous_batched_examples([seqlens],
num_sequences,
max_seqlen,
last_taken,
exhausted,
n_heads,
d_head,
itype,
return_naive_ref=False))
seqlens = torch.tensor(seqlens, dtype=torch.int32, device=X.device)
device = X.device
## full seqlen computation
chunk_indices, chunk_offsets = \
_query_start_loc_to_chunk_indices_offsets(
cu_seqlens, chunk_size, cu_seqlens[-1])
Y_ref = torch.empty_like(X)
state_ref = mamba_chunk_scan_combined(
X,
dt,
A,
B,
C,
chunk_size,
D=None,
cu_seqlens=cu_seqlens,
seq_idx=seq_idx,
chunk_indices=chunk_indices,
chunk_offsets=chunk_offsets,
return_varlen_states=True,
initial_states=None,
out=Y_ref,
)
## chunked seqlen computation
# first chunk
chunked_seqlens = seqlens // 2
chunked_cu_seqlens = torch.cat([
torch.tensor([0], device=device),
torch.cumsum(chunked_seqlens, dim=0)
],
dim=0)
chunked_seq_idx = torch.repeat_interleave(
torch.arange(len(chunked_seqlens), device=device),
chunked_seqlens,
output_size=chunked_cu_seqlens[-1]).unsqueeze(0).to(torch.int32)
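    # e.g. for seqlens=(16, 2, 8, 13): chunked_seqlens=(8, 1, 4, 6),
    # chunked_cu_seqlens=[0, 8, 9, 13, 19], and chunked_seq_idx (a single
    # batch row) is [0]*8 + [1] + [2]*4 + [3]*6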
chunked_input_seq_len = chunked_cu_seqlens[-1]
X_chunked = torch.zeros_like(X)[:, :chunked_input_seq_len, ...]
dt_chunked = torch.zeros_like(dt)[:, :chunked_input_seq_len, ...]
B_chunked = torch.zeros_like(B)[:, :chunked_input_seq_len, ...]
C_chunked = torch.zeros_like(C)[:, :chunked_input_seq_len, ...]
for i in range(num_sequences):
# fmt: off
chunk_f = lambda x, i: x[:, cu_seqlens[i]:cu_seqlens[i] + chunked_seqlens[i], ...] # noqa: E501
X_chunked[:, chunked_cu_seqlens[i]:chunked_cu_seqlens[i+1], ...] = chunk_f(X, i) # noqa: E501
dt_chunked[:, chunked_cu_seqlens[i]:chunked_cu_seqlens[i+1], ...] = chunk_f(dt, i) # noqa: E501
B_chunked[:, chunked_cu_seqlens[i]:chunked_cu_seqlens[i+1], ...] = chunk_f(B, i) # noqa: E501
C_chunked[:, chunked_cu_seqlens[i]:chunked_cu_seqlens[i+1], ...] = chunk_f(C, i) # noqa: E501
# fmt: on
chunk_indices, chunk_offsets = \
_query_start_loc_to_chunk_indices_offsets(
chunked_cu_seqlens, chunk_size, chunked_cu_seqlens[-1])
Y_partial = torch.empty_like(X_chunked)
partial_state = mamba_chunk_scan_combined(
X_chunked,
dt_chunked,
A,
B_chunked,
C_chunked,
chunk_size,
D=None,
cu_seqlens=chunked_cu_seqlens,
seq_idx=chunked_seq_idx,
chunk_indices=chunk_indices,
chunk_offsets=chunk_offsets,
return_varlen_states=True,
initial_states=None,
out=Y_partial,
)
# remaining chunk
remaining_chunked_seqlens = seqlens - chunked_seqlens
remaining_chunked_cu_seqlens = torch.cat([
torch.tensor([0], device=device),
torch.cumsum(remaining_chunked_seqlens, dim=0)
],
dim=0)
remaining_chunked_seq_idx = torch.repeat_interleave(
torch.arange(len(remaining_chunked_seqlens), device=device),
remaining_chunked_seqlens,
output_size=remaining_chunked_cu_seqlens[-1]).unsqueeze(0).to(
torch.int32)
remaining_chunked_input_seq_len = remaining_chunked_cu_seqlens[-1]
# fmt: off
remaining_X_chunked = torch.zeros_like(X)[:, :remaining_chunked_input_seq_len, ...] # noqa: E501
remaining_dt_chunked = torch.zeros_like(dt)[:, :remaining_chunked_input_seq_len, ...] # noqa: E501
remaining_B_chunked = torch.zeros_like(B)[:, :remaining_chunked_input_seq_len, ...] # noqa: E501
remaining_C_chunked = torch.zeros_like(C)[:, :remaining_chunked_input_seq_len, ...] # noqa: E501
for i in range(num_sequences):
remaining_chunk_f = lambda x, i: x[:, cu_seqlens[i] + chunked_seqlens[i]:cu_seqlens[i+1], ...] # noqa: E501
remaining_X_chunked[:, remaining_chunked_cu_seqlens[i]:remaining_chunked_cu_seqlens[i+1], ...] = remaining_chunk_f(X, i) # noqa: E501
remaining_dt_chunked[:, remaining_chunked_cu_seqlens[i]:remaining_chunked_cu_seqlens[i+1], ...] = remaining_chunk_f(dt, i) # noqa: E501
remaining_B_chunked[:, remaining_chunked_cu_seqlens[i]:remaining_chunked_cu_seqlens[i+1], ...] = remaining_chunk_f(B, i) # noqa: E501
remaining_C_chunked[:, remaining_chunked_cu_seqlens[i]:remaining_chunked_cu_seqlens[i+1], ...] = remaining_chunk_f(C, i) # noqa: E501
# assert input chunking is correct
concat_chunk_f = lambda pt1, pt2, i: torch.cat([
pt1[:,chunked_cu_seqlens[i]:chunked_cu_seqlens[i+1],...],
pt2[:,remaining_chunked_cu_seqlens[i]:remaining_chunked_cu_seqlens[i+1],...],
],
dim=1)
concat_batch_f = lambda pt1, pt2: torch.cat([concat_chunk_f(pt1, pt2, i) for i in range(num_sequences)], dim=1) # noqa: E501
# fmt: on
assert concat_batch_f(X_chunked, remaining_X_chunked).equal(X)
assert concat_batch_f(dt_chunked, remaining_dt_chunked).equal(dt)
assert concat_batch_f(B_chunked, remaining_B_chunked).equal(B)
assert concat_batch_f(C_chunked, remaining_C_chunked).equal(C)
chunk_indices, chunk_offsets = \
_query_start_loc_to_chunk_indices_offsets(
remaining_chunked_cu_seqlens,
chunk_size,
remaining_chunked_cu_seqlens[-1])
Y_chunked = torch.empty_like(remaining_X_chunked)
state_chunked = mamba_chunk_scan_combined(
remaining_X_chunked,
remaining_dt_chunked,
A,
remaining_B_chunked,
remaining_C_chunked,
chunk_size,
D=None,
cu_seqlens=remaining_chunked_cu_seqlens,
seq_idx=remaining_chunked_seq_idx,
chunk_indices=chunk_indices,
chunk_offsets=chunk_offsets,
return_varlen_states=True,
initial_states=partial_state,
out=Y_chunked,
)
Y = concat_batch_f(Y_partial, Y_chunked)
# kernel chunked is same as kernel overall
for i in range(num_sequences):
Y_seq = Y[:, cu_seqlens[i]:cu_seqlens[i + 1], ...]
Y_ref_seq = Y_ref[:, cu_seqlens[i]:cu_seqlens[i + 1], ...]
torch.testing.assert_close(
Y_seq[:, :chunked_seqlens[i], ...],
Y_ref_seq[:, :chunked_seqlens[i], ...],
atol=atol,
rtol=rtol,
msg=lambda x: f"seq{i} output part1 " + x) # noqa: B023
torch.testing.assert_close(
Y_seq[:, chunked_seqlens[i]:, ...],
Y_ref_seq[:, chunked_seqlens[i]:, ...],
atol=atol,
rtol=rtol,
msg=lambda x: f"seq{i} output part2 " + x) # noqa: B023
state_seq = state_chunked[i]
state_seq_ref = state_ref[i]
torch.testing.assert_close(
state_seq,
state_seq_ref,
atol=atol,
rtol=rtol,
msg=lambda x: f"seq{i} state " + x) # noqa: B023


@@ -289,6 +289,9 @@ def _chunk_scan_fwd_kernel(
# get the cs at the offset boundary
# - c_off == 0 is a passthrough
# - We need dA_cs at the boundary, defined by c_off - no need
# to increase pointer by pid_m (it is a constant offset,
# i.e. the same for all blocks)
    dA_cs_m_boundary = tl.load(
        dA_cumsum_ptr + (c_off - 1) * stride_dA_cs_csize,
        mask=(((c_off - 1) > -1) and ((c_off) < chunk_size)),
        other=0.0).to(tl.float32)
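
The effect of that boundary load is easiest to see in plain Python (a minimal sketch with illustrative numbers, not the kernel's tensors): subtracting the cumsum entry just before c_off re-bases a chunk-wide cumulative sum so it counts only from the sequence that begins mid-chunk.

# Illustrative sketch: one physical chunk whose second sequence starts at
# offset c_off. Subtracting dA_cs[c_off - 1] re-bases the cumsum so it
# counts only from that sequence's first token.
import itertools

dA = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8]  # dt*A values of one chunk
c_off = 5                                      # a new sequence starts here
dA_cs = list(itertools.accumulate(dA))         # within-chunk cumsum

dA_cs_boundary = dA_cs[c_off - 1] if c_off > 0 else 0.0
from_seq_start = [dA_cs[t] - dA_cs_boundary for t in range(c_off, len(dA))]
assert all(abs(a - b) < 1e-9 for a, b in
           zip(from_seq_start, itertools.accumulate(dA[c_off:])))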


@@ -106,21 +106,24 @@ def _mamba_chunk_scan_combined_fwd(x,
# 3. Compute the inter-chunk SSM recurrence; produces correct SSM states at chunk boundaries
# (middle term of factorization of off-diag blocks; A terms)
# - for handling chunked prefill, this requires i) initial_states
-    #   ii) seq_idx and iii) is_cont_batched to be all specified.
+    #   ii) seq_idx, iii) is_cont_batched, and iv) chunk_offsets to all be
+    #   specified.
# - When a new seq_idx is detected, we will stop passing the prev_state
# and switch accordingly to the init_state corresponding to the new seq_idx.
# - We will also make sure that the dA_cumsum is taken only from the start of the
# sequence (hence we need the full dA_cumsum tensor and not just the values at chunk boundaries)
# - this will ensure that states will be updated with the rightmost flushed seq_idx
# of the previous chunk. This implies that the first chunk of states is either 0
# or equal to init_states of the first example.
states, final_states = _state_passing_fwd(
rearrange(states, "... p n -> ... (p n)"),
-        dA_cumsum[:, :, :, -1],
+        dA_cumsum,
initial_states=rearrange(initial_states, "... p n -> ... (p n)")
if initial_states is not None else None,
seq_idx=seq_idx,
chunk_size=chunk_size,
out_dtype=state_dtype if state_dtype is not None else C.dtype,
-        is_cont_batched=cu_seqlens is not None)
+        is_cont_batched=cu_seqlens is not None,
+        chunk_offsets=chunk_offsets)
states, final_states = (rearrange(t, "... (p n) -> ... p n", n=dstate)
for t in [states, final_states])
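
As a mental model for this state-passing step, here is a scalar pure-Python sketch (a hypothetical helper; the real states are per-head (p, n) tensors, and when a sequence starts mid-chunk the kernel additionally re-bases dA_cs via chunk_offsets as described above):

import math

def state_passing_sketch(chunk_states, dA_cs_chunk_end, seq_idx_chunk_end,
                         init_states):
    # chunk_states[c]: within-chunk contribution of chunk c (scalar stand-in)
    # dA_cs_chunk_end[c]: dA cumsum at the last token of chunk c
    # seq_idx_chunk_end[c]: sequence index at the last token of chunk c
    # init_states[s]: initial state of sequence s
    out = []
    state, prev_seq = 0.0, 0
    for c, new_state in enumerate(chunk_states):
        out.append(state)  # state *entering* chunk c
        seq = seq_idx_chunk_end[c]
        if seq != prev_seq:
            # a new sequence is flushed to the chunk's right boundary:
            # restart from its init state instead of propagating
            state = init_states[seq]
        state = math.exp(dA_cs_chunk_end[c]) * state + new_state
        prev_seq = seq
    return out, state  # per-chunk entering states, final state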


@@ -31,6 +31,8 @@ def _state_passing_fwd_kernel(
dA_cs_ptr,
initstates_ptr,
seq_idx_ptr,
chunk_offsets_ptr,
chunk_meta_num,
# Matrix dimensions
dim,
nchunks,
@@ -51,6 +53,7 @@ def _state_passing_fwd_kernel(
stride_dA_cs_batch,
stride_dA_cs_chunk,
stride_dA_cs_head,
stride_dA_cs_csize,
stride_initstates_batch,
stride_initstates_head,
stride_initstates_dim,
@@ -66,7 +69,8 @@ def _state_passing_fwd_kernel(
pid_h = tl.program_id(axis=2)
pid_m = tl.program_id(axis=0)
states_ptr += pid_b * stride_states_batch + pid_h * stride_states_head
-    dA_cs_ptr += pid_b * stride_dA_cs_batch + pid_h * stride_dA_cs_head
+    dA_cs_ptr += pid_b * stride_dA_cs_batch + pid_h * stride_dA_cs_head + (
+        chunk_size - 1) * stride_dA_cs_csize
out_ptr += pid_b * stride_out_batch + pid_h * stride_out_head
final_states_ptr += pid_b * stride_final_states_batch + pid_h * stride_final_states_head
if HAS_INITSTATES:
@@ -95,35 +99,62 @@ def _state_passing_fwd_kernel(
tl.store(out_ptrs, states, mask=offs_m < dim)
out_ptrs += stride_out_chunk
-    seq_idx = 0
+    prev_seq_idx_chunk_end = 0
+    logical_chunk_idx = 0
for c in range(nchunks):
new_states = tl.load(states_ptrs, mask=offs_m < dim,
other=0.0).to(tl.float32)
dA_cs = tl.load(dA_cs_ptr).to(tl.float32)
-        scale = tl.exp(dA_cs)
+        scale_mask = True
if HAS_SEQ_IDX:
# - the seq to pass forward is the one that is flushed to the right
# boundary.
-            # - that is given by seq_idx_new below.
-            seq_idx_new = tl.load(seq_idx_ptr +
-                                  (min((c + 1) * chunk_size, seqlen) - 1) *
-                                  stride_seq_idx_seqlen)
+            # - that is given by seq_idx_chunk_end below: the sequence index
+            #   at the end of the chunk.
+            seq_idx_chunk_end = tl.load(seq_idx_ptr + (min(
+                (c + 1) * chunk_size, seqlen) - 1) * stride_seq_idx_seqlen)
if HAS_INITSTATES:
-                if IS_CONT_BATCHED and seq_idx != seq_idx_new:
+                if IS_CONT_BATCHED and prev_seq_idx_chunk_end != seq_idx_chunk_end:
# this means in the current chunk the rightmost flushed seq
# has changed.
# - so we do not propagate the state from previous chunk
# - but rather we load that sequence's init state
-                    initstates_ptrs = initstates_ptr + seq_idx_new * stride_initstates_batch
+                    initstates_ptrs = initstates_ptr + seq_idx_chunk_end * stride_initstates_batch
                    # - update state with that sequence's init state
states = tl.load(initstates_ptrs,
mask=offs_m < dim,
other=0.0).to(tl.float32)
-                else:
-                    scale = tl.where(seq_idx_new == seq_idx, scale, 0.0)
-            seq_idx = seq_idx_new
+                    # - we need to consider the cumsum only of the last
+                    #   sequence in the chunk
+                    # - find its starting position (given by c_off of the
+                    #   logical chunk index)
+                    # - and subtract the cumsum just before that position
+                    #   from the total cumsum
+                    # - first, update the logical chunk index (add the number
+                    #   of sequences in the current physical chunk):
+                    #   sequence index at the start of the current chunk
+                    seq_idx_chunk_start = tl.load(seq_idx_ptr +
+                                                  min(c * chunk_size, seqlen) *
+                                                  stride_seq_idx_seqlen)
+                    logical_chunk_idx += seq_idx_chunk_end - seq_idx_chunk_start
+                    # - load the chunk offset:
+                    c_off = tl.load(chunk_offsets_ptr + logical_chunk_idx,
+                                    mask=logical_chunk_idx < chunk_meta_num,
+                                    other=0)
+                    # - if the offset is 0, the sequence starts at the
+                    #   beginning of the chunk, and we don't need to
+                    #   subtract anything
+                    if c_off > 0:
+                        # - dA_cs_ptr currently points to the cumsum at the
+                        #   end of the chunk - subtract the chunk size and
+                        #   add the offset
+                        dA_cs_boundary = tl.load(
+                            dA_cs_ptr - (chunk_size - 1) * stride_dA_cs_csize +
+                            (c_off - 1) * stride_dA_cs_csize,
+                            mask=(c_off - 1) > -1 and c_off < chunk_size,
+                            other=0.0)
+                        dA_cs -= dA_cs_boundary
+                # - increment the logical chunk index for every physical chunk
+                logical_chunk_idx += 1
+            else:
+                scale_mask = seq_idx_chunk_end == prev_seq_idx_chunk_end
+            prev_seq_idx_chunk_end = seq_idx_chunk_end
+        scale = tl.where(scale_mask, tl.exp(dA_cs), 0.0)
states = scale * states + new_states
if c < nchunks - 1:
tl.store(out_ptrs, states, mask=offs_m < dim)
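
To make the logical-chunk bookkeeping concrete, here is a pure-Python trace of the loop above on the example from the _query_start_loc_to_chunk_indices_offsets docstring later in this diff (two sequences of 5 tokens, chunk_size=8, chunk_offsets=[0, 5, 0]); a sketch with plain lists standing in for the Triton loads:

chunk_size, seqlen = 8, 10
seq_idx = [0] * 5 + [1] * 5   # token -> sequence id
chunk_offsets = [0, 5, 0]     # one entry per logical chunk

logical_chunk_idx = 0
prev_seq_idx_chunk_end = 0
nchunks = (seqlen + chunk_size - 1) // chunk_size
for c in range(nchunks):
    seq_idx_chunk_start = seq_idx[min(c * chunk_size, seqlen - 1)]
    seq_idx_chunk_end = seq_idx[min((c + 1) * chunk_size, seqlen) - 1]
    if prev_seq_idx_chunk_end != seq_idx_chunk_end:
        # advance by the number of sequence boundaries crossed in this chunk
        logical_chunk_idx += seq_idx_chunk_end - seq_idx_chunk_start
        c_off = chunk_offsets[logical_chunk_idx]
        print(f"chunk {c}: new seq {seq_idx_chunk_end} starts at offset {c_off}")
    logical_chunk_idx += 1  # once per physical chunk
    prev_seq_idx_chunk_end = seq_idx_chunk_end
# prints: chunk 0: new seq 1 starts at offset 5
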
@@ -136,28 +167,36 @@ def _state_passing_fwd_kernel(
def _state_passing_fwd(
states,
-    dA_chunk_cumsum,
+    dA_cumsum,
initial_states=None,
seq_idx=None,
chunk_size=None,
out_dtype=None,
is_cont_batched=False,
chunk_offsets=None,
):
batch, nchunks, nheads, dim = states.shape
-    assert dA_chunk_cumsum.shape == (batch, nheads, nchunks)
+    if chunk_size is None:
+        chunk_size = dA_cumsum.shape[-1]
+    else:
+        assert chunk_size == dA_cumsum.shape[-1]
+    assert dA_cumsum.shape == (batch, nheads, nchunks, chunk_size)
if initial_states is not None:
if is_cont_batched:
# - if cu_seqlens is provided, then the initial states
# are used for continuous batching. In which case we
# require seq_idx to be provided
-            assert seq_idx is not None, ""
+            assert seq_idx is not None, "seq_idx must be provided for continuous batching"
# - we also need chunk_offsets to be provided, to account
# for computation of dA_cumsum from the start of the
# sequence
assert chunk_offsets is not None, "chunk_offsets must be provided for continuous batching"
else:
            # - this is the regular batching case, where initial
            #   states are used for each example of the batch.
assert initial_states.shape == (batch, nheads, dim)
if seq_idx is not None:
assert chunk_size is not None
seqlen = seq_idx.shape[-1]
assert seq_idx.shape == (batch, seqlen)
out_dtype = states.dtype if out_dtype is None else out_dtype
@@ -173,13 +212,15 @@ def _state_passing_fwd(
states,
out,
final_states,
-        dA_chunk_cumsum,
+        dA_cumsum,
initial_states,
seq_idx,
chunk_offsets,
len(chunk_offsets) if chunk_offsets is not None else 0,
dim,
nchunks,
seqlen if seq_idx is not None else 0,
-        chunk_size if seq_idx is not None else 0,
+        chunk_size,
states.stride(0),
states.stride(1),
states.stride(2),
@@ -191,9 +232,10 @@ def _state_passing_fwd(
final_states.stride(0),
final_states.stride(1),
final_states.stride(2),
-        dA_chunk_cumsum.stride(0),
-        dA_chunk_cumsum.stride(2),
-        dA_chunk_cumsum.stride(1),
+        dA_cumsum.stride(0),
+        dA_cumsum.stride(2),
+        dA_cumsum.stride(1),
+        dA_cumsum.stride(3),
*((initial_states.stride(0), initial_states.stride(1),
initial_states.stride(2)) if initial_states is not None else
(0, 0, 0)),


@@ -16,9 +16,58 @@ from vllm.v1.attention.backends.utils import (CommonAttentionMetadata,
from vllm.v1.kv_cache_interface import AttentionSpec
-def _query_start_loc_to_chunk_indices_offsets(query_start_loc: torch.Tensor,
-                                              chunk_size: int,
-                                              total_seqlens: int):
+def _query_start_loc_to_chunk_indices_offsets(
+        query_start_loc: torch.Tensor, chunk_size: int,
+        total_seqlens: int) -> tuple[torch.Tensor, torch.Tensor]:
"""
Args:
query_start_loc (torch.Tensor): 1D tensor of cumulative sequence
lengths, shape (num_seqs + 1,).
The first element should be 0. Each entry represents the starting
index of a sequence in the flattened token array.
chunk_size (int): The size of each physical mamba chunk
(number of tokens per chunk).
total_seqlens (int): The total number of tokens in the batch.
Returns:
Tuple[torch.Tensor, torch.Tensor]: A tuple containing:
- chunk_indices (torch.Tensor): 1D tensor of indices
indicating the physical chunk for each logical chunk.
- chunk_offsets (torch.Tensor): 1D tensor of offsets
indicating the starting index of each logical chunk within
its physical chunk.
This function computes the chunk indices and offsets for the given
query_start_loc and chunk_size. Both are tensors of integers with length N,
where N is the number of logical (pseudo) chunks.
A logical chunk is a sequence of tokens that are all part of the same
sequence and are all in the same physical mamba chunk.
In other words, a logical chunk changes every time we cross a sequence
boundary or a physical mamba chunk boundary.
Logical chunks are needed to handle batched requests with initial states
(see _state_passing_fwd and _chunk_scan_fwd).
The chunk_indices tensor contains the index of the physical chunk for each
logical chunk.
    The chunk_offsets tensor contains the offset (i.e. starting index) of each
    logical chunk within its physical chunk.
Example:
query_start_loc = [0, 5, 10]
chunk_size = 8
total_seqlens = 10
-> chunk_indices = [0, 0, 1]
-> chunk_offsets = [0, 5, 0]
In this example, we have 2 sequences, each with 5 tokens. The physical
chunk size is 8 tokens.
We have three logical chunks:
- the first logical chunk starts at token 0 in the first physical chunk
and contains all 5 tokens from the first sequence
    - the second logical chunk starts at token 5 in the first physical chunk
      and contains the first 3 tokens of the second sequence
- the third logical chunk starts at token 0 in the second physical chunk
and contains the remaining 2 tokens from the second sequence
"""
cu_seqlens = query_start_loc[1:] # remove prepended 0
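
The docstring's example can be reproduced with a naive token-by-token sketch (a hypothetical helper for illustration only; the actual implementation computes the same result with tensor ops):

def naive_chunk_indices_offsets(query_start_loc, chunk_size, total_seqlens):
    # token -> owning sequence id
    seq_of = []
    for s in range(len(query_start_loc) - 1):
        seq_of += [s] * (query_start_loc[s + 1] - query_start_loc[s])
    indices, offsets = [], []
    prev = None
    for t in range(total_seqlens):
        key = (seq_of[t], t // chunk_size)  # (sequence, physical chunk)
        if key != prev:  # a new logical chunk starts at token t
            indices.append(t // chunk_size)
            offsets.append(t % chunk_size)
            prev = key
    return indices, offsets

assert naive_chunk_indices_offsets([0, 5, 10], 8, 10) == ([0, 0, 1], [0, 5, 0])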