Fix test_mamba_ssm_ssd.py due to missing _query_start_loc_to_chunk_indices_offsets (#25995)
Signed-off-by: Huamin Li <3ericli@gmail.com>
This commit is contained in:
Parent: 5234dc7451
Commit: c36f0aa300
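The tests imported the private helper _query_start_loc_to_chunk_indices_offsets from vllm.v1.attention.backends.mamba2_attn, which no longer exists; this change switches them to a new exported helper, compute_varlen_chunk_metadata, and updates every mamba_chunk_scan_combined_varlen call site from the old (seq_idx, chunk_indices, chunk_offsets) keywords to the new (cu_chunk_seqlens, last_chunk_indices, seq_idx) metadata. A minimal before/after sketch of the call shape, using the tensor names from the tests below:

# Old (removed): per-token seq_idx plus chunk_indices/chunk_offsets.
# chunk_indices, chunk_offsets = \
#     _query_start_loc_to_chunk_indices_offsets(
#         cu_seqlens, chunk_size, cu_seqlens[-1])

# New: one helper returns all chunk-aligned metadata, and seq_idx is
# now per logical chunk rather than per token.
cu_chunk_seqlens, last_chunk_indices, seq_idx_chunks = (
    compute_varlen_chunk_metadata(cu_seqlens, chunk_size))
final_state = mamba_chunk_scan_combined_varlen(
    X, dt, A, B, C, chunk_size,
    cu_seqlens=cu_seqlens.to(torch.int32),
    cu_chunk_seqlens=cu_chunk_seqlens,
    last_chunk_indices=last_chunk_indices,
    seq_idx=seq_idx_chunks,
    out=Y,
    D=None,
)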
tests/kernels/mamba/test_mamba_ssm_ssd.py
@@ -10,7 +10,7 @@ from vllm.model_executor.layers.mamba.ops.ssd_combined import (
     mamba_chunk_scan_combined_varlen)
 from vllm.platforms import current_platform
 from vllm.v1.attention.backends.mamba2_attn import (
-    _query_start_loc_to_chunk_indices_offsets)
+    compute_varlen_chunk_metadata)

 # Added by the IBM Team, 2024

@@ -225,13 +225,9 @@ def test_mamba_chunk_scan_single_example(d_head, n_heads, seq_len_chunk_size,
     Y_min, final_state_min = ssd_minimal_discrete(X * dt.unsqueeze(-1), A * dt,
                                                   B, C, chunk_size)

-    cu_seqlens = torch.tensor((0, seqlen), device='cuda').cumsum(dim=0)
-    seq_idx = torch.zeros(seqlen, dtype=torch.int32, device=cu_seqlens.device)
-    chunk_indices, chunk_offsets = \
-        _query_start_loc_to_chunk_indices_offsets(
-            cu_seqlens, chunk_size, cu_seqlens[-1])
+    cu_seqlens = torch.tensor((0, seqlen), device="cuda").cumsum(dim=0)
+    cu_chunk_seqlens, last_chunk_indices, seq_idx_chunks = (
+        compute_varlen_chunk_metadata(cu_seqlens, chunk_size))

     # varlen has implicit batch=1
     X = X.squeeze(0)
     dt = dt.squeeze(0)
@@ -239,18 +235,20 @@ def test_mamba_chunk_scan_single_example(d_head, n_heads, seq_len_chunk_size,
     B = B.squeeze(0)
     C = C.squeeze(0)
     Y = torch.empty_like(X)
-    final_state = mamba_chunk_scan_combined_varlen(X,
-                                                   dt,
-                                                   A,
-                                                   B,
-                                                   C,
-                                                   chunk_size,
-                                                   D=None,
-                                                   cu_seqlens=cu_seqlens,
-                                                   seq_idx=seq_idx,
-                                                   chunk_indices=chunk_indices,
-                                                   chunk_offsets=chunk_offsets,
-                                                   out=Y)
+    final_state = mamba_chunk_scan_combined_varlen(
+        X,
+        dt,
+        A,
+        B,
+        C,
+        chunk_size,
+        cu_seqlens=cu_seqlens.to(torch.int32),
+        cu_chunk_seqlens=cu_chunk_seqlens,
+        last_chunk_indices=last_chunk_indices,
+        seq_idx=seq_idx_chunks,
+        out=Y,
+        D=None,
+    )

     # just test the last in sequence
     torch.testing.assert_close(Y[-1], Y_min[0, -1], atol=atol, rtol=rtol)
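For intuition (not part of the commit), the single-sequence metadata the new helper produces can be worked out by hand from its implementation in mamba2_attn.py below; the sizes here are hypothetical:

# One sequence of 10 tokens, chunk_size=4 -> logical chunks [4, 4, 2].
cu_seqlens = torch.tensor([0, 10], device="cuda")
cu_chunk_seqlens, last_chunk_indices, seq_idx_chunks = (
    compute_varlen_chunk_metadata(cu_seqlens, chunk_size=4))
# cu_chunk_seqlens   -> [0, 4, 8, 10]  (chunk boundaries, int32)
# last_chunk_indices -> [2]            (the sequence's final chunk index)
# seq_idx_chunks     -> [0, 0, 0]      (every chunk belongs to sequence 0)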
@@ -312,14 +310,13 @@ def test_mamba_chunk_scan_cont_batch(d_head, n_heads, seq_len_chunk_size_cases,
     exhausted: dict = {}  # map: eg -> boolean indicating example is exhausted

     states = None
-    for Y_min, cu_seqlens, seq_idx, (
+    for Y_min, cu_seqlens, _token_seq_idx, (
             A, dt, X, B, C) in generate_continuous_batched_examples(
                 cases, num_examples, seqlen, last_taken, exhausted, n_heads,
                 d_head, itype):

-        chunk_indices, chunk_offsets = \
-            _query_start_loc_to_chunk_indices_offsets(
-                cu_seqlens, chunk_size, cu_seqlens[-1])
+        cu_chunk_seqlens, last_chunk_indices, seq_idx_chunks = (
+            compute_varlen_chunk_metadata(cu_seqlens, chunk_size))

         Y = torch.empty_like(X)
         new_states = mamba_chunk_scan_combined_varlen(
@@ -329,13 +326,13 @@ def test_mamba_chunk_scan_cont_batch(d_head, n_heads, seq_len_chunk_size_cases,
             B,
             C,
             chunk_size,
-            D=None,
-            cu_seqlens=cu_seqlens,
-            seq_idx=seq_idx,
-            chunk_indices=chunk_indices,
-            chunk_offsets=chunk_offsets,
-            initial_states=states,
+            cu_seqlens=cu_seqlens.to(torch.int32),
+            cu_chunk_seqlens=cu_chunk_seqlens,
+            last_chunk_indices=last_chunk_indices,
+            seq_idx=seq_idx_chunks,
             out=Y,
+            D=None,
+            initial_states=states,
         )

         # just test the last in sequence
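Note the semantic shift in this test: generate_continuous_batched_examples still yields a per-token sequence index (now unpacked as _token_seq_idx and ignored), while the kernel consumes the per-chunk seq_idx_chunks from the helper. A hand-worked two-sequence sketch with hypothetical sizes:

# Two sequences of 6 and 3 tokens, chunk_size=4: cu_seqlens = [0, 6, 9].
# Sequence 0 (tokens 0..5) splits at the physical boundary: chunks [4, 2].
# Sequence 1 (tokens 6..8) starts mid-chunk, so it also splits: chunks [2, 1].
# cu_chunk_seqlens   -> [0, 4, 6, 8, 9]
# seq_idx_chunks     -> [0, 0, 1, 1]
# last_chunk_indices -> [1, 3]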
@@ -403,9 +400,8 @@ def test_mamba_chunk_scan_cont_batch_prefill_chunking(chunk_size, seqlens):
     device = X.device

     ## full seqlen computation
-    chunk_indices, chunk_offsets = \
-        _query_start_loc_to_chunk_indices_offsets(
-            cu_seqlens, chunk_size, cu_seqlens[-1])
+    cu_chunk_seqlens, last_chunk_indices, seq_idx_chunks = (
+        compute_varlen_chunk_metadata(cu_seqlens, chunk_size))
     Y_ref = torch.empty_like(X)
     state_ref = mamba_chunk_scan_combined_varlen(
         X,
@@ -414,13 +410,13 @@ def test_mamba_chunk_scan_cont_batch_prefill_chunking(chunk_size, seqlens):
         B,
         C,
         chunk_size,
-        D=None,
-        cu_seqlens=cu_seqlens,
-        seq_idx=seq_idx,
-        chunk_indices=chunk_indices,
-        chunk_offsets=chunk_offsets,
-        initial_states=None,
+        cu_seqlens=cu_seqlens.to(torch.int32),
+        cu_chunk_seqlens=cu_chunk_seqlens,
+        last_chunk_indices=last_chunk_indices,
+        seq_idx=seq_idx_chunks,
         out=Y_ref,
+        D=None,
+        initial_states=None,
     )

     ## chunked seqlen computation
@@ -431,10 +427,6 @@ def test_mamba_chunk_scan_cont_batch_prefill_chunking(chunk_size, seqlens):
             torch.cumsum(chunked_seqlens, dim=0)
         ],
         dim=0)
-    chunked_seq_idx = torch.repeat_interleave(
-        torch.arange(len(chunked_seqlens), device=device),
-        chunked_seqlens,
-        output_size=chunked_cu_seqlens[-1]).to(torch.int32)
     chunked_input_seq_len = chunked_cu_seqlens[-1]
     X_chunked = torch.zeros_like(X)[:chunked_input_seq_len, ...]
     dt_chunked = torch.zeros_like(dt)[:chunked_input_seq_len, ...]
@@ -450,9 +442,8 @@ def test_mamba_chunk_scan_cont_batch_prefill_chunking(chunk_size, seqlens):
         C_chunked[chunked_cu_seqlens[i]:chunked_cu_seqlens[i+1], ...] = chunk_f(C, i)  # noqa: E501
     # fmt: on

-    chunk_indices, chunk_offsets = \
-        _query_start_loc_to_chunk_indices_offsets(
-            chunked_cu_seqlens, chunk_size, chunked_cu_seqlens[-1])
+    cu_chunk_seqlens, last_chunk_indices, seq_idx_chunks = (
+        compute_varlen_chunk_metadata(chunked_cu_seqlens, chunk_size))
     Y_partial = torch.empty_like(X_chunked)
     partial_state = mamba_chunk_scan_combined_varlen(
         X_chunked,
@@ -461,13 +452,13 @@ def test_mamba_chunk_scan_cont_batch_prefill_chunking(chunk_size, seqlens):
         B_chunked,
         C_chunked,
         chunk_size,
-        D=None,
-        cu_seqlens=chunked_cu_seqlens,
-        seq_idx=chunked_seq_idx,
-        chunk_indices=chunk_indices,
-        chunk_offsets=chunk_offsets,
-        initial_states=None,
+        cu_seqlens=chunked_cu_seqlens.to(torch.int32),
+        cu_chunk_seqlens=cu_chunk_seqlens,
+        last_chunk_indices=last_chunk_indices,
+        seq_idx=seq_idx_chunks,
         out=Y_partial,
+        D=None,
+        initial_states=None,
     )

     # remaining chunk
@@ -477,10 +468,6 @@ def test_mamba_chunk_scan_cont_batch_prefill_chunking(chunk_size, seqlens):
             torch.cumsum(remaining_chunked_seqlens, dim=0)
         ],
         dim=0)
-    remaining_chunked_seq_idx = torch.repeat_interleave(
-        torch.arange(len(remaining_chunked_seqlens), device=device),
-        remaining_chunked_seqlens,
-        output_size=remaining_chunked_cu_seqlens[-1]).to(torch.int32)
     remaining_chunked_input_seq_len = remaining_chunked_cu_seqlens[-1]
     # fmt: off
     remaining_X_chunked = torch.zeros_like(X)[:remaining_chunked_input_seq_len, ...]  # noqa: E501
@@ -509,11 +496,9 @@ def test_mamba_chunk_scan_cont_batch_prefill_chunking(chunk_size, seqlens):
     assert concat_batch_f(B_chunked, remaining_B_chunked).equal(B)
     assert concat_batch_f(C_chunked, remaining_C_chunked).equal(C)

-    chunk_indices, chunk_offsets = \
-        _query_start_loc_to_chunk_indices_offsets(
-            remaining_chunked_cu_seqlens,
-            chunk_size,
-            remaining_chunked_cu_seqlens[-1])
+    cu_chunk_seqlens, last_chunk_indices, seq_idx_chunks = (
+        compute_varlen_chunk_metadata(remaining_chunked_cu_seqlens,
+                                      chunk_size))

     Y_chunked = torch.empty_like(remaining_X_chunked)
     state_chunked = mamba_chunk_scan_combined_varlen(
@@ -523,13 +508,13 @@ def test_mamba_chunk_scan_cont_batch_prefill_chunking(chunk_size, seqlens):
         remaining_B_chunked,
         remaining_C_chunked,
         chunk_size,
-        D=None,
-        cu_seqlens=remaining_chunked_cu_seqlens,
-        seq_idx=remaining_chunked_seq_idx,
-        chunk_indices=chunk_indices,
-        chunk_offsets=chunk_offsets,
-        initial_states=partial_state,
+        cu_seqlens=remaining_chunked_cu_seqlens.to(torch.int32),
+        cu_chunk_seqlens=cu_chunk_seqlens,
+        last_chunk_indices=last_chunk_indices,
+        seq_idx=seq_idx_chunks,
         out=Y_chunked,
+        D=None,
+        initial_states=partial_state,
     )
     Y = concat_batch_f(Y_partial, Y_chunked)

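The prefill-chunking test above exercises resuming a sequence: the first portion runs with initial_states=None and returns partial_state, the remainder runs with initial_states=partial_state, and the concatenated outputs must match the single full-length reference call. A runnable metadata-only sketch with hypothetical sizes (the helper itself is CPU-friendly):

import torch
from vllm.v1.attention.backends.mamba2_attn import compute_varlen_chunk_metadata

# A 10-token prefill split into 6 + 4 tokens with chunk_size=4: the chunk
# metadata must be rebuilt for each partial call, while the SSM state is
# carried across calls via initial_states instead.
full = compute_varlen_chunk_metadata(torch.tensor([0, 10]), 4)
first = compute_varlen_chunk_metadata(torch.tensor([0, 6]), 4)
rest = compute_varlen_chunk_metadata(torch.tensor([0, 4]), 4)
assert full[0].tolist() == [0, 4, 8, 10]  # full prefill: chunks [4, 4, 2]
assert first[0].tolist() == [0, 4, 6]     # first call: chunks [4, 2]
assert rest[0].tolist() == [0, 4]         # resumed call: chunks [4]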
vllm/v1/attention/backends/mamba2_attn.py
@@ -1,5 +1,6 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+import itertools
 from dataclasses import dataclass
 from typing import Optional

@@ -17,6 +18,75 @@ from vllm.v1.attention.backends.utils import (PAD_SLOT_ID,
 from vllm.v1.kv_cache_interface import AttentionSpec


+def compute_varlen_chunk_metadata(
+    query_start_loc: torch.Tensor,
+    chunk_size: int,
+) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+    """
+    Build chunk-aligned, variable-length metadata used by Mamba2 SSD kernels.
+
+    Given per-sequence cumulative token starts `query_start_loc` of shape
+    [B+1] and a physical `chunk_size`, returns three tensors on the same
+    device:
+    - cu_chunk_seqlens: (nchunks+1,) int32 exclusive prefix-sum of
+      logical-chunk lengths (each logical chunk never crosses a sequence or
+      physical-chunk boundary).
+    - last_chunk_indices: (B,) int32 index of the last logical chunk
+      for each sequence (=-1 for empty sequences).
+    - seq_idx_chunks: (nchunks,) int32 sequence index for each logical
+      chunk in order.
+
+    This is intentionally lightweight and CPU-side; it mirrors the metadata
+    produced by the V1 Mamba2 metadata builder and is exported so tests
+    (and other callers) can avoid duplicating the logic.
+    """
+    assert query_start_loc.ndim == 1, "query_start_loc must be 1-D [B+1]"
+    assert int(query_start_loc[0].item()) == 0, "query_start_loc[0] must be 0"
+    device = query_start_loc.device
+
+    qsl64 = query_start_loc.to(torch.int64)
+    starts = qsl64[:-1].tolist()
+    ends = qsl64[1:].tolist()
+    total = int(qsl64[-1].item())
+
+    chunk_lens: list[int] = []
+    seq_idx_chunks: list[int] = []
+    last_chunk_indices: list[int] = [-1] * len(starts)
+
+    for b, (s, e) in enumerate(zip(starts, ends)):
+        if e <= s:
+            # empty sequence
+            continue
+        pos = s
+        while pos < e:
+            # split at both sequence boundaries and physical chunk boundaries
+            room = chunk_size - (pos % chunk_size)
+            take = min(room, e - pos)
+            chunk_lens.append(int(take))
+            seq_idx_chunks.append(b)
+            last_chunk_indices[b] = len(chunk_lens) - 1
+            pos += take
+
+    # Exclusive prefix sum over logical-chunk lengths
+    if chunk_lens:
+        cu_chunk_seqlens = torch.tensor([0] +
+                                        list(itertools.accumulate(chunk_lens)),
+                                        device=device,
+                                        dtype=torch.int32)
+        # Final boundary must equal total tokens
+        assert int(cu_chunk_seqlens[-1].item()) == total
+    else:
+        cu_chunk_seqlens = torch.tensor([0], device=device, dtype=torch.int32)
+
+    last_chunk_indices_t = (torch.tensor(
+        last_chunk_indices, device=device, dtype=torch.int32)
+                            if len(starts) > 0 else torch.empty(
+                                (0, ), device=device, dtype=torch.int32))
+    seq_idx_chunks_t = torch.tensor(seq_idx_chunks,
+                                    device=device,
+                                    dtype=torch.int32)
+    return cu_chunk_seqlens, last_chunk_indices_t, seq_idx_chunks_t
+
+
 class Mamba2AttentionBackend(AttentionBackend):

     @staticmethod
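As a quick sanity check of the new helper (not part of the commit), here are its expected outputs for a small batch, worked out by hand from the loop above; note the mid-chunk split and the -1 marker for an empty sequence:

import torch
from vllm.v1.attention.backends.mamba2_attn import compute_varlen_chunk_metadata

# Sequences of lengths 5, 0, 7 with chunk_size=4.
qsl = torch.tensor([0, 5, 5, 12])
cu, last, seq = compute_varlen_chunk_metadata(qsl, 4)
assert cu.tolist() == [0, 4, 5, 8, 12]  # chunks [4, 1] + [3, 4]
assert last.tolist() == [1, -1, 3]      # empty sequence 1 keeps -1
assert seq.tolist() == [0, 0, 2, 2]     # per-chunk sequence indices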