mirror of https://git.datalinker.icu/vllm-project/vllm.git
[Attention][DBO] Add support for "splitting" the CommonAttentionMetadata (#21153)
Signed-off-by: Sage Moore <sage@neuralmagic.com>
parent 6e8d8c4afb
commit 0edaf752d7
tests/v1/attention/test_attention_splitting.py (new file, 157 additions)
@@ -0,0 +1,157 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import pytest
import torch

from tests.v1.attention.test_attention_backends import BATCH_SPECS
from tests.v1.attention.utils import create_common_attn_metadata
from vllm.v1.attention.backends.utils import (UbatchSlice,
                                              _make_metadata_with_slice,
                                              slice_query_start_locs,
                                              split_attn_metadata)


@pytest.fixture
def sample_query_start_loc():
    """Sample query_start_loc tensor for testing"""
    return torch.tensor([0, 5, 12, 20, 35, 50])
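

# Note on the expected values below: slice_query_start_locs(qsl, slice(a, b))
# returns the boundaries qsl[a : b + 1] rebased so the first entry is 0. For
# the fixture above, slice(1, 3) selects [5, 12, 20], which rebases to
# [0, 7, 15].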


def test_basic_slice_middle(sample_query_start_loc):
    """Test slicing from the middle of the tensor"""
    req_slice = slice(1, 3)  # slice from index 1 to 3
    result = slice_query_start_locs(sample_query_start_loc, req_slice)

    expected = torch.tensor([0, 7, 15])
    assert torch.equal(result, expected)


def test_slice_from_beginning(sample_query_start_loc):
    """Test slicing from the beginning of the tensor"""
    req_slice = slice(0, 2)  # slice from index 0 to 2
    result = slice_query_start_locs(sample_query_start_loc, req_slice)

    expected = torch.tensor([0, 5, 12])
    assert torch.equal(result, expected)


def test_slice_to_end(sample_query_start_loc):
    """Test slicing to the end of the tensor"""
    req_slice = slice(3, 5)  # slice from index 3 to 5 (last index)
    result = slice_query_start_locs(sample_query_start_loc, req_slice)

    expected = torch.tensor([0, 15, 30])
    assert torch.equal(result, expected)


def test_single_element_slice(sample_query_start_loc):
    """Test a slice that results in a single element"""
    req_slice = slice(2, 3)  # slice from index 2 to 3
    result = slice_query_start_locs(sample_query_start_loc, req_slice)

    expected = torch.tensor([0, 8])
    assert torch.equal(result, expected)


def test_full_tensor_slice(sample_query_start_loc):
    """Test slicing the entire tensor"""
    req_slice = slice(0, 5)  # slice the entire tensor
    result = slice_query_start_locs(sample_query_start_loc, req_slice)

    expected = torch.tensor([0, 5, 12, 20, 35, 50])
    assert torch.equal(result, expected)


def test_slice_bounds_edge_cases(sample_query_start_loc):
    # Test a slice that goes exactly to the last element
    req_slice = slice(4, 5)  # last index
    result = slice_query_start_locs(sample_query_start_loc, req_slice)

    expected = torch.tensor([0, 15])
    assert torch.equal(result, expected)


@pytest.fixture
def small_decode_metadata():
    """Create metadata for a small decode batch"""
    batch_spec = BATCH_SPECS["small_decode"]
    device = torch.device("cpu")
    return create_common_attn_metadata(batch_spec,
                                       block_size=16,
                                       device=device)


@pytest.fixture
def large_decode_metadata():
    """Create metadata for a large decode batch"""
    batch_spec = BATCH_SPECS["large_decode"]
    device = torch.device("cpu")
    return create_common_attn_metadata(batch_spec,
                                       block_size=16,
                                       device=device)


@pytest.fixture
def mixed_small_metadata():
    """Create metadata for a mixed small batch"""
    batch_spec = BATCH_SPECS["mixed_small"]
    device = torch.device("cpu")
    return create_common_attn_metadata(batch_spec,
                                       block_size=16,
                                       device=device)
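

# The concrete num_reqs, seq_lens, and token counts asserted below follow the
# BATCH_SPECS definitions in tests/v1/attention/test_attention_backends.py,
# which this diff does not show.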


# Tests for _make_metadata_with_slice
def test_make_metadata_with_slice_decode_batch(small_decode_metadata):
    """Test slicing decode batch metadata"""
    # Split off the first request only
    ubatch_slice = UbatchSlice(slice(0, 1), slice(0, 1))

    result = _make_metadata_with_slice(ubatch_slice, small_decode_metadata)

    # Check the sliced results
    assert result.num_reqs == 1  # slice(0, 1) gives 1 request
    assert result.num_actual_tokens == 1  # slice(0, 1) gives 1 token
    assert result.max_query_len == 1
    assert torch.equal(result.query_start_loc, torch.tensor([0, 1]))
    assert torch.equal(result.seq_lens, torch.tensor([32]))


def test_make_metadata_with_slice_mixed_batch(mixed_small_metadata):
    """Test slicing mixed batch metadata"""
    ubatch_slice = UbatchSlice(slice(1, 3),
                               slice(1, 7))  # requests [1, 3), tokens [1, 7)

    result = _make_metadata_with_slice(ubatch_slice, mixed_small_metadata)

    assert result.num_reqs == 2  # slice(1, 3) gives 2 requests
    assert result.num_actual_tokens == 6  # slice(1, 7) gives 6 tokens
    assert result.max_query_len == 5
    assert torch.equal(result.query_start_loc, torch.tensor([0, 1, 6]))
    assert torch.equal(result.seq_lens, torch.tensor([40, 48]))


def test_split_attn_metadata_decode_batch(large_decode_metadata):
    """Test splitting a decode batch into two equal parts"""
    # In a decode batch each request contributes exactly one token, so the
    # request count doubles as the token count.
    num_tokens = large_decode_metadata.num_reqs
    mid_point = num_tokens // 2
    ubatch_slices = [
        UbatchSlice(slice(0, mid_point), slice(0, mid_point)),
        UbatchSlice(slice(mid_point, num_tokens), slice(mid_point,
                                                        num_tokens)),
    ]

    results = split_attn_metadata(ubatch_slices, large_decode_metadata)

    assert len(results) == 2

    # Check the first split
    assert results[0].num_reqs == mid_point
    assert results[0].num_actual_tokens == mid_point
    assert torch.equal(results[0].seq_lens, torch.tensor([2048] * mid_point))

    # Check the second split
    assert results[1].num_reqs == mid_point
    assert results[1].num_actual_tokens == mid_point
    assert torch.equal(results[1].seq_lens, torch.tensor([2048] * mid_point))
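
These tests should be runnable on their own, e.g. with
pytest tests/v1/attention/test_attention_splitting.py. The helpers they
exercise are added to vllm/v1/attention/backends/utils.py (the module named
by the import above) in the hunk below.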

@@ -63,6 +63,89 @@ class CommonAttentionMetadata:
    causal: bool = True


@dataclass
class UbatchSlice:
    request_slice: slice
    token_slice: slice
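

# A UbatchSlice pairs a request_slice, which indexes per-request tensors such
# as seq_lens and block_table_tensor, with a token_slice, which indexes
# flattened per-token tensors such as slot_mapping.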


def slice_query_start_locs(
    query_start_loc: torch.Tensor,
    request_slice: slice,
) -> torch.Tensor:
    """
    Creates a new query_start_loc that corresponds to the requests in
    request_slice.

    Note: This function creates a new tensor to hold the new query_start_locs.
    This will break cudagraph compatibility, since CUDA graphs replay kernels
    against fixed tensor addresses and a freshly allocated output cannot be
    captured.
    """
    return query_start_loc[request_slice.start:request_slice.stop + 1] - \
        query_start_loc[request_slice.start]


def _make_metadata_with_slice(
        ubatch_slice: UbatchSlice,
        attn_metadata: CommonAttentionMetadata) -> CommonAttentionMetadata:
    """
    This function creates a new CommonAttentionMetadata that corresponds to
    the requests included in ubatch_slice.
    """

    request_slice = ubatch_slice.request_slice
    token_slice = ubatch_slice.token_slice

    query_start_loc = slice_query_start_locs(attn_metadata.query_start_loc,
                                             request_slice)
    # A valid ubatch covers at least one request, i.e. two boundary entries.
    assert len(query_start_loc) >= 2
    query_start_loc_cpu = slice_query_start_locs(
        attn_metadata.query_start_loc_cpu, request_slice)

    seq_lens = attn_metadata.seq_lens[request_slice]
    seq_lens_cpu = attn_metadata.seq_lens_cpu[request_slice]
    num_computed_tokens_cpu = attn_metadata.num_computed_tokens_cpu[
        request_slice]

    num_requests = request_slice.stop - request_slice.start
    num_actual_tokens = token_slice.stop - token_slice.start
    max_query_len = int(
        torch.max(torch.abs(query_start_loc_cpu[1:] -
                            query_start_loc_cpu[:-1])).item())

    block_table_tensor = attn_metadata.block_table_tensor[request_slice]
    slot_mapping = attn_metadata.slot_mapping[token_slice]

    return CommonAttentionMetadata(
        query_start_loc=query_start_loc,
        query_start_loc_cpu=query_start_loc_cpu,
        seq_lens=seq_lens,
        seq_lens_cpu=seq_lens_cpu,
        num_computed_tokens_cpu=num_computed_tokens_cpu,
        num_reqs=num_requests,
        num_actual_tokens=num_actual_tokens,
        max_query_len=max_query_len,
        block_table_tensor=block_table_tensor,
        slot_mapping=slot_mapping,
    )
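

# Note: the two slices are not cross-checked here; callers are expected to
# pass a token_slice that covers exactly the tokens of the requests selected
# by request_slice, as the unit tests above do.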


def split_attn_metadata(
    ubatch_slices: list[UbatchSlice],
    common_attn_metadata: CommonAttentionMetadata,
) -> list[CommonAttentionMetadata]:
    """
    Creates a new CommonAttentionMetadata instance for each UbatchSlice in
    ubatch_slices, corresponding to that slice's requests.

    Note: This function does not modify common_attn_metadata.
    """
    results = []
    for ubatch_slice in ubatch_slices:
        results.append(
            _make_metadata_with_slice(ubatch_slice, common_attn_metadata))
    return results
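

# Example usage, mirroring test_split_attn_metadata_decode_batch above (a
# sketch: `metadata` stands for an existing CommonAttentionMetadata for a
# decode batch, where each request contributes exactly one token):
#
#   mid = metadata.num_reqs // 2
#   ubatch_slices = [
#       UbatchSlice(slice(0, mid), slice(0, mid)),
#       UbatchSlice(slice(mid, metadata.num_reqs),
#                   slice(mid, metadata.num_reqs)),
#   ]
#   first_half, second_half = split_attn_metadata(ubatch_slices, metadata)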
M = TypeVar("M")
|
||||
|
||||
|
||||
|
||||