mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-09 07:55:01 +08:00
Fix test_mamba_ssm_ssd.py due to missing _query_start_loc_to_chunk_indices_offsets (#25995)
Signed-off-by: Huamin Li <3ericli@gmail.com>
This commit is contained in:
parent
5234dc7451
commit
c36f0aa300
@ -10,7 +10,7 @@ from vllm.model_executor.layers.mamba.ops.ssd_combined import (
|
||||
mamba_chunk_scan_combined_varlen)
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.v1.attention.backends.mamba2_attn import (
|
||||
_query_start_loc_to_chunk_indices_offsets)
|
||||
compute_varlen_chunk_metadata)
|
||||
|
||||
# Added by the IBM Team, 2024
|
||||
|
||||
@ -225,13 +225,9 @@ def test_mamba_chunk_scan_single_example(d_head, n_heads, seq_len_chunk_size,
|
||||
Y_min, final_state_min = ssd_minimal_discrete(X * dt.unsqueeze(-1), A * dt,
|
||||
B, C, chunk_size)
|
||||
|
||||
cu_seqlens = torch.tensor((0, seqlen), device='cuda').cumsum(dim=0)
|
||||
seq_idx = torch.zeros(seqlen, dtype=torch.int32, device=cu_seqlens.device)
|
||||
|
||||
chunk_indices, chunk_offsets = \
|
||||
_query_start_loc_to_chunk_indices_offsets(
|
||||
cu_seqlens, chunk_size, cu_seqlens[-1])
|
||||
|
||||
cu_seqlens = torch.tensor((0, seqlen), device="cuda").cumsum(dim=0)
|
||||
cu_chunk_seqlens, last_chunk_indices, seq_idx_chunks = (
|
||||
compute_varlen_chunk_metadata(cu_seqlens, chunk_size))
|
||||
# varlen has implicit batch=1
|
||||
X = X.squeeze(0)
|
||||
dt = dt.squeeze(0)
|
||||
@ -239,18 +235,20 @@ def test_mamba_chunk_scan_single_example(d_head, n_heads, seq_len_chunk_size,
|
||||
B = B.squeeze(0)
|
||||
C = C.squeeze(0)
|
||||
Y = torch.empty_like(X)
|
||||
final_state = mamba_chunk_scan_combined_varlen(X,
|
||||
dt,
|
||||
A,
|
||||
B,
|
||||
C,
|
||||
chunk_size,
|
||||
D=None,
|
||||
cu_seqlens=cu_seqlens,
|
||||
seq_idx=seq_idx,
|
||||
chunk_indices=chunk_indices,
|
||||
chunk_offsets=chunk_offsets,
|
||||
out=Y)
|
||||
final_state = mamba_chunk_scan_combined_varlen(
|
||||
X,
|
||||
dt,
|
||||
A,
|
||||
B,
|
||||
C,
|
||||
chunk_size,
|
||||
cu_seqlens=cu_seqlens.to(torch.int32),
|
||||
cu_chunk_seqlens=cu_chunk_seqlens,
|
||||
last_chunk_indices=last_chunk_indices,
|
||||
seq_idx=seq_idx_chunks,
|
||||
out=Y,
|
||||
D=None,
|
||||
)
|
||||
|
||||
# just test the last in sequence
|
||||
torch.testing.assert_close(Y[-1], Y_min[0, -1], atol=atol, rtol=rtol)
|
||||
@ -312,14 +310,13 @@ def test_mamba_chunk_scan_cont_batch(d_head, n_heads, seq_len_chunk_size_cases,
|
||||
exhausted: dict = {} # map: eg -> boolean indicating example is exhausted
|
||||
|
||||
states = None
|
||||
for Y_min, cu_seqlens, seq_idx, (
|
||||
for Y_min, cu_seqlens, _token_seq_idx, (
|
||||
A, dt, X, B, C) in generate_continuous_batched_examples(
|
||||
cases, num_examples, seqlen, last_taken, exhausted, n_heads,
|
||||
d_head, itype):
|
||||
|
||||
chunk_indices, chunk_offsets = \
|
||||
_query_start_loc_to_chunk_indices_offsets(
|
||||
cu_seqlens, chunk_size, cu_seqlens[-1])
|
||||
cu_chunk_seqlens, last_chunk_indices, seq_idx_chunks = (
|
||||
compute_varlen_chunk_metadata(cu_seqlens, chunk_size))
|
||||
|
||||
Y = torch.empty_like(X)
|
||||
new_states = mamba_chunk_scan_combined_varlen(
|
||||
@ -329,13 +326,13 @@ def test_mamba_chunk_scan_cont_batch(d_head, n_heads, seq_len_chunk_size_cases,
|
||||
B,
|
||||
C,
|
||||
chunk_size,
|
||||
D=None,
|
||||
cu_seqlens=cu_seqlens,
|
||||
seq_idx=seq_idx,
|
||||
chunk_indices=chunk_indices,
|
||||
chunk_offsets=chunk_offsets,
|
||||
initial_states=states,
|
||||
cu_seqlens=cu_seqlens.to(torch.int32),
|
||||
cu_chunk_seqlens=cu_chunk_seqlens,
|
||||
last_chunk_indices=last_chunk_indices,
|
||||
seq_idx=seq_idx_chunks,
|
||||
out=Y,
|
||||
D=None,
|
||||
initial_states=states,
|
||||
)
|
||||
|
||||
# just test the last in sequence
|
||||
@ -403,9 +400,8 @@ def test_mamba_chunk_scan_cont_batch_prefill_chunking(chunk_size, seqlens):
|
||||
device = X.device
|
||||
|
||||
## full seqlen computation
|
||||
chunk_indices, chunk_offsets = \
|
||||
_query_start_loc_to_chunk_indices_offsets(
|
||||
cu_seqlens, chunk_size, cu_seqlens[-1])
|
||||
cu_chunk_seqlens, last_chunk_indices, seq_idx_chunks = (
|
||||
compute_varlen_chunk_metadata(cu_seqlens, chunk_size))
|
||||
Y_ref = torch.empty_like(X)
|
||||
state_ref = mamba_chunk_scan_combined_varlen(
|
||||
X,
|
||||
@ -414,13 +410,13 @@ def test_mamba_chunk_scan_cont_batch_prefill_chunking(chunk_size, seqlens):
|
||||
B,
|
||||
C,
|
||||
chunk_size,
|
||||
D=None,
|
||||
cu_seqlens=cu_seqlens,
|
||||
seq_idx=seq_idx,
|
||||
chunk_indices=chunk_indices,
|
||||
chunk_offsets=chunk_offsets,
|
||||
initial_states=None,
|
||||
cu_seqlens=cu_seqlens.to(torch.int32),
|
||||
cu_chunk_seqlens=cu_chunk_seqlens,
|
||||
last_chunk_indices=last_chunk_indices,
|
||||
seq_idx=seq_idx_chunks,
|
||||
out=Y_ref,
|
||||
D=None,
|
||||
initial_states=None,
|
||||
)
|
||||
|
||||
## chunked seqlen computation
|
||||
@ -431,10 +427,6 @@ def test_mamba_chunk_scan_cont_batch_prefill_chunking(chunk_size, seqlens):
|
||||
torch.cumsum(chunked_seqlens, dim=0)
|
||||
],
|
||||
dim=0)
|
||||
chunked_seq_idx = torch.repeat_interleave(
|
||||
torch.arange(len(chunked_seqlens), device=device),
|
||||
chunked_seqlens,
|
||||
output_size=chunked_cu_seqlens[-1]).to(torch.int32)
|
||||
chunked_input_seq_len = chunked_cu_seqlens[-1]
|
||||
X_chunked = torch.zeros_like(X)[:chunked_input_seq_len, ...]
|
||||
dt_chunked = torch.zeros_like(dt)[:chunked_input_seq_len, ...]
|
||||
@ -450,9 +442,8 @@ def test_mamba_chunk_scan_cont_batch_prefill_chunking(chunk_size, seqlens):
|
||||
C_chunked[chunked_cu_seqlens[i]:chunked_cu_seqlens[i+1], ...] = chunk_f(C, i) # noqa: E501
|
||||
# fmt: on
|
||||
|
||||
chunk_indices, chunk_offsets = \
|
||||
_query_start_loc_to_chunk_indices_offsets(
|
||||
chunked_cu_seqlens, chunk_size, chunked_cu_seqlens[-1])
|
||||
cu_chunk_seqlens, last_chunk_indices, seq_idx_chunks = (
|
||||
compute_varlen_chunk_metadata(chunked_cu_seqlens, chunk_size))
|
||||
Y_partial = torch.empty_like(X_chunked)
|
||||
partial_state = mamba_chunk_scan_combined_varlen(
|
||||
X_chunked,
|
||||
@ -461,13 +452,13 @@ def test_mamba_chunk_scan_cont_batch_prefill_chunking(chunk_size, seqlens):
|
||||
B_chunked,
|
||||
C_chunked,
|
||||
chunk_size,
|
||||
D=None,
|
||||
cu_seqlens=chunked_cu_seqlens,
|
||||
seq_idx=chunked_seq_idx,
|
||||
chunk_indices=chunk_indices,
|
||||
chunk_offsets=chunk_offsets,
|
||||
initial_states=None,
|
||||
cu_seqlens=chunked_cu_seqlens.to(torch.int32),
|
||||
cu_chunk_seqlens=cu_chunk_seqlens,
|
||||
last_chunk_indices=last_chunk_indices,
|
||||
seq_idx=seq_idx_chunks,
|
||||
out=Y_partial,
|
||||
D=None,
|
||||
initial_states=None,
|
||||
)
|
||||
|
||||
# remaining chunk
|
||||
@ -477,10 +468,6 @@ def test_mamba_chunk_scan_cont_batch_prefill_chunking(chunk_size, seqlens):
|
||||
torch.cumsum(remaining_chunked_seqlens, dim=0)
|
||||
],
|
||||
dim=0)
|
||||
remaining_chunked_seq_idx = torch.repeat_interleave(
|
||||
torch.arange(len(remaining_chunked_seqlens), device=device),
|
||||
remaining_chunked_seqlens,
|
||||
output_size=remaining_chunked_cu_seqlens[-1]).to(torch.int32)
|
||||
remaining_chunked_input_seq_len = remaining_chunked_cu_seqlens[-1]
|
||||
# fmt: off
|
||||
remaining_X_chunked = torch.zeros_like(X)[:remaining_chunked_input_seq_len, ...] # noqa: E501
|
||||
@ -509,11 +496,9 @@ def test_mamba_chunk_scan_cont_batch_prefill_chunking(chunk_size, seqlens):
|
||||
assert concat_batch_f(B_chunked, remaining_B_chunked).equal(B)
|
||||
assert concat_batch_f(C_chunked, remaining_C_chunked).equal(C)
|
||||
|
||||
chunk_indices, chunk_offsets = \
|
||||
_query_start_loc_to_chunk_indices_offsets(
|
||||
remaining_chunked_cu_seqlens,
|
||||
chunk_size,
|
||||
remaining_chunked_cu_seqlens[-1])
|
||||
cu_chunk_seqlens, last_chunk_indices, seq_idx_chunks = (
|
||||
compute_varlen_chunk_metadata(remaining_chunked_cu_seqlens,
|
||||
chunk_size))
|
||||
|
||||
Y_chunked = torch.empty_like(remaining_X_chunked)
|
||||
state_chunked = mamba_chunk_scan_combined_varlen(
|
||||
@ -523,13 +508,13 @@ def test_mamba_chunk_scan_cont_batch_prefill_chunking(chunk_size, seqlens):
|
||||
remaining_B_chunked,
|
||||
remaining_C_chunked,
|
||||
chunk_size,
|
||||
D=None,
|
||||
cu_seqlens=remaining_chunked_cu_seqlens,
|
||||
seq_idx=remaining_chunked_seq_idx,
|
||||
chunk_indices=chunk_indices,
|
||||
chunk_offsets=chunk_offsets,
|
||||
initial_states=partial_state,
|
||||
cu_seqlens=remaining_chunked_cu_seqlens.to(torch.int32),
|
||||
cu_chunk_seqlens=cu_chunk_seqlens,
|
||||
last_chunk_indices=last_chunk_indices,
|
||||
seq_idx=seq_idx_chunks,
|
||||
out=Y_chunked,
|
||||
D=None,
|
||||
initial_states=partial_state,
|
||||
)
|
||||
Y = concat_batch_f(Y_partial, Y_chunked)
|
||||
|
||||
|
||||
@ -1,5 +1,6 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import itertools
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional
|
||||
|
||||
@ -17,6 +18,75 @@ from vllm.v1.attention.backends.utils import (PAD_SLOT_ID,
|
||||
from vllm.v1.kv_cache_interface import AttentionSpec
|
||||
|
||||
|
||||
def compute_varlen_chunk_metadata(
|
||||
query_start_loc: torch.Tensor,
|
||||
chunk_size: int,
|
||||
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Build chunk-aligned, variable-length metadata used by Mamba2 SSD kernels.
|
||||
|
||||
Given per-sequence cumulative token starts `query_start_loc` of shape [B+1]
|
||||
and a physical `chunk_size`, returns three tensors on the same device:
|
||||
- cu_chunk_seqlens: (nchunks+1,) int32 exclusive prefix-sum of
|
||||
logical-chunk lengths (each logical chunk never crosses a sequence or
|
||||
physical-chunk boundary).
|
||||
- last_chunk_indices: (B,) int32 index of the last logical chunk
|
||||
for each sequence (=-1 for empty sequences).
|
||||
- seq_idx_chunks: (nchunks,) int32 sequence index for each logical
|
||||
chunk in order.
|
||||
|
||||
This is intentionally lightweight and CPU-side; it mirrors the metadata
|
||||
produced by the V1 Mamba2 meta-data builder and is exported so tests
|
||||
(and other callers) can avoid duplicating the logic.
|
||||
"""
|
||||
assert query_start_loc.ndim == 1, "query_start_loc must be 1-D [B+1]"
|
||||
assert int(query_start_loc[0].item()) == 0, "query_start_loc[0] must be 0"
|
||||
device = query_start_loc.device
|
||||
|
||||
qsl64 = query_start_loc.to(torch.int64)
|
||||
starts = qsl64[:-1].tolist()
|
||||
ends = qsl64[1:].tolist()
|
||||
total = int(qsl64[-1].item())
|
||||
|
||||
chunk_lens: list[int] = []
|
||||
seq_idx_chunks: list[int] = []
|
||||
last_chunk_indices: list[int] = [-1] * len(starts)
|
||||
|
||||
for b, (s, e) in enumerate(zip(starts, ends)):
|
||||
if e <= s:
|
||||
# empty sequence
|
||||
continue
|
||||
pos = s
|
||||
while pos < e:
|
||||
# split at both sequence boundaries and physical chunk boundaries
|
||||
room = chunk_size - (pos % chunk_size)
|
||||
take = min(room, e - pos)
|
||||
chunk_lens.append(int(take))
|
||||
seq_idx_chunks.append(b)
|
||||
last_chunk_indices[b] = len(chunk_lens) - 1
|
||||
pos += take
|
||||
|
||||
# Exclusive prefix sum over logical-chunk lengths
|
||||
if chunk_lens:
|
||||
cu_chunk_seqlens = torch.tensor([0] +
|
||||
list(itertools.accumulate(chunk_lens)),
|
||||
device=device,
|
||||
dtype=torch.int32)
|
||||
# Final boundary must equal total tokens
|
||||
assert int(cu_chunk_seqlens[-1].item()) == total
|
||||
else:
|
||||
cu_chunk_seqlens = torch.tensor([0], device=device, dtype=torch.int32)
|
||||
|
||||
last_chunk_indices_t = (torch.tensor(
|
||||
last_chunk_indices, device=device, dtype=torch.int32)
|
||||
if len(starts) > 0 else torch.empty(
|
||||
(0, ), device=device, dtype=torch.int32))
|
||||
seq_idx_chunks_t = torch.tensor(seq_idx_chunks,
|
||||
device=device,
|
||||
dtype=torch.int32)
|
||||
return cu_chunk_seqlens, last_chunk_indices_t, seq_idx_chunks_t
|
||||
|
||||
|
||||
class Mamba2AttentionBackend(AttentionBackend):
|
||||
|
||||
@staticmethod
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user