diff --git a/tests/v1/attention/test_attention_splitting.py b/tests/v1/attention/test_attention_splitting.py
new file mode 100644
index 0000000000000..3fc1011d5042e
--- /dev/null
+++ b/tests/v1/attention/test_attention_splitting.py
@@ -0,0 +1,157 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+import pytest
+import torch
+
+from tests.v1.attention.test_attention_backends import BATCH_SPECS
+from tests.v1.attention.utils import create_common_attn_metadata
+from vllm.v1.attention.backends.utils import (UbatchSlice,
+                                              _make_metadata_with_slice,
+                                              slice_query_start_locs,
+                                              split_attn_metadata)
+
+
+@pytest.fixture
+def sample_query_start_loc():
+    """Sample query_start_loc tensor for testing"""
+    return torch.tensor([0, 5, 12, 20, 35, 50])
+
+
+def test_basic_slice_middle(sample_query_start_loc):
+    """Test slicing from the middle of the tensor"""
+    req_slice = slice(1, 3)  # slice from index 1 to 3
+    result = slice_query_start_locs(sample_query_start_loc, req_slice)
+
+    expected = torch.tensor([0, 7, 15])
+    assert torch.equal(result, expected)
+
+
+def test_slice_from_beginning(sample_query_start_loc):
+    """Test slicing from the beginning of the tensor"""
+    req_slice = slice(0, 2)  # slice from index 0 to 2
+    result = slice_query_start_locs(sample_query_start_loc, req_slice)
+
+    expected = torch.tensor([0, 5, 12])
+    assert torch.equal(result, expected)
+
+
+def test_slice_to_end(sample_query_start_loc):
+    """Test slicing to the end of the tensor"""
+    req_slice = slice(3, 5)  # slice from index 3 to 5 (last index)
+    result = slice_query_start_locs(sample_query_start_loc, req_slice)
+
+    expected = torch.tensor([0, 15, 30])
+    assert torch.equal(result, expected)
+
+
+def test_single_element_slice(sample_query_start_loc):
+    """Test a slice that results in a single element"""
+    req_slice = slice(2, 3)  # slice from index 2 to 3
+    result = slice_query_start_locs(sample_query_start_loc, req_slice)
+
+    expected = torch.tensor([0, 8])
+    assert torch.equal(result, expected)
+
+
+def test_full_tensor_slice(sample_query_start_loc):
+    """Test slicing the entire tensor"""
+    req_slice = slice(0, 5)  # slice entire tensor
+    result = slice_query_start_locs(sample_query_start_loc, req_slice)
+
+    expected = torch.tensor([0, 5, 12, 20, 35, 50])
+    assert torch.equal(result, expected)
+
+
+def test_slice_bounds_edge_cases(sample_query_start_loc):
+    # Test slice that goes exactly to the last element
+    req_slice = slice(4, 5)  # Last index
+    result = slice_query_start_locs(sample_query_start_loc, req_slice)
+
+    expected = torch.tensor([0, 15])
+    assert torch.equal(result, expected)
+
+
+@pytest.fixture
+def small_decode_metadata():
+    """Create metadata for a small decode batch"""
+    batch_spec = BATCH_SPECS["small_decode"]
+    device = torch.device("cpu")
+    return create_common_attn_metadata(batch_spec,
+                                       block_size=16,
+                                       device=device)
+
+
+@pytest.fixture
+def large_decode_metadata():
+    """Create metadata for a large decode batch"""
+    batch_spec = BATCH_SPECS["large_decode"]
+    device = torch.device("cpu")
+    return create_common_attn_metadata(batch_spec,
+                                       block_size=16,
+                                       device=device)
+
+
+@pytest.fixture
+def mixed_small_metadata():
+    """Create metadata for a mixed small batch"""
+    batch_spec = BATCH_SPECS["mixed_small"]
+    device = torch.device("cpu")
+    return create_common_attn_metadata(batch_spec,
+                                       block_size=16,
+                                       device=device)
+
+
+# Tests for _make_metadata_with_slice
+def test_make_metadata_with_slice_decode_batch(small_decode_metadata):
+    """Test slicing decode batch metadata"""
+    # Split first request only
+    ubatch_slice = UbatchSlice(slice(0, 1), slice(0, 1))
+
+    result = _make_metadata_with_slice(ubatch_slice, small_decode_metadata)
+
+    # Check sliced results
+    assert result.num_reqs == 1  # slice(0, 1) gives 1 request
+    assert result.num_actual_tokens == 1  # slice(0, 1) gives 1 token
+    assert result.max_query_len == 1
+    assert torch.equal(result.query_start_loc, torch.tensor([0, 1]))
+    assert torch.equal(result.seq_lens, torch.tensor([32]))
+
+
+def test_make_metadata_with_slice_mixed_batch(mixed_small_metadata):
+    """Test slicing mixed batch metadata"""
+    ubatch_slice = UbatchSlice(slice(1, 3),
+                               slice(1, 7))  # Requests 1-3, tokens 1-7
+
+    result = _make_metadata_with_slice(ubatch_slice, mixed_small_metadata)
+
+    assert result.num_reqs == 2  # slice(1, 3) gives 2 requests
+    assert result.num_actual_tokens == 6  # slice(1, 7) gives 6 tokens
+    assert result.max_query_len == 5
+    assert torch.equal(result.query_start_loc, torch.tensor([0, 1, 6]))
+    assert torch.equal(result.seq_lens, torch.tensor([40, 48]))
+
+
+def test_split_attn_metadata_decode_batch(large_decode_metadata):
+    """Test splitting decode batch into two equal parts"""
+    # Pure decode batch: one scheduled token per request
+    num_tokens = large_decode_metadata.num_reqs
+    mid_point = num_tokens // 2
+    ubatch_slices = [
+        UbatchSlice(slice(0, mid_point), slice(0, mid_point)),
+        UbatchSlice(slice(mid_point, num_tokens), slice(mid_point,
+                                                        num_tokens)),
+    ]
+
+    results = split_attn_metadata(ubatch_slices, large_decode_metadata)
+
+    assert len(results) == 2
+
+    # Check first split
+    assert results[0].num_reqs == mid_point
+    assert results[0].num_actual_tokens == mid_point
+    assert torch.equal(results[0].seq_lens, torch.tensor([2048] * mid_point))
+
+    # Check second split
+    assert results[1].num_reqs == mid_point
+    assert results[1].num_actual_tokens == mid_point
+    assert torch.equal(results[1].seq_lens, torch.tensor([2048] * mid_point))
diff --git a/vllm/v1/attention/backends/utils.py b/vllm/v1/attention/backends/utils.py
index 68c51340a2ac4..2fb768c9d710d 100644
--- a/vllm/v1/attention/backends/utils.py
+++ b/vllm/v1/attention/backends/utils.py
@@ -76,6 +76,7 @@ def slice_query_start_locs(
     """
     Creates a new query_start_loc that corresponds to the requests in
     request_slice.
+
     Note: This function creates a new tensor to hold the new query_start_locs.
     This will break cudagraph compatibility.
     """
@@ -129,19 +130,19 @@ def _make_metadata_with_slice(
 
 
 def split_attn_metadata(
-    ubatch_slices: list[tuple[slice, slice]],
+    ubatch_slices: list[UbatchSlice],
     common_attn_metadata: CommonAttentionMetadata,
 ) -> list[CommonAttentionMetadata]:
     """
     Creates a new CommonAttentionMetadata instance that corresponds to the
     requests for each UbatchSlice in ubatch_slices.
+    Note: This function does not modify common_attn_metadata
     """
 
     results = []
     for ubatch_slice in ubatch_slices:
-        s = UbatchSlice(request_slice=ubatch_slice[0],
-                        token_slice=ubatch_slice[1])
-        results.append(_make_metadata_with_slice(s, common_attn_metadata))
+        results.append(
+            _make_metadata_with_slice(ubatch_slice, common_attn_metadata))
 
     return results
diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py
index c9bf227a872ea..1d5fa4b9f7d68 100644
--- a/vllm/v1/worker/gpu_model_runner.py
+++ b/vllm/v1/worker/gpu_model_runner.py
@@ -52,7 +52,7 @@ from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadata
 from vllm.v1.attention.backends.mamba_selectors import get_mamba_attn_backend
 from vllm.v1.attention.backends.utils import (
     AttentionCGSupport, AttentionMetadataBuilder, CommonAttentionMetadata,
-    make_kv_sharing_fast_prefill_attention_metadata,
+    UbatchSlice, make_kv_sharing_fast_prefill_attention_metadata,
     make_local_attention_virtual_batches, split_attn_metadata)
 from vllm.v1.core.encoder_cache_manager import compute_encoder_budget
 from vllm.v1.kv_cache_interface import (AttentionSpec,
@@ -100,7 +100,6 @@
 AttnMetadataDict: TypeAlias = dict[str, FlashAttentionMetadata]
 PerLayerAttnMetadata: TypeAlias = Union[list[AttnMetadataDict],
                                         AttnMetadataDict]
-UbatchSlice: TypeAlias = tuple[slice, slice]
 UBatchSlices: TypeAlias = list[UbatchSlice]
 
 
@@ -656,10 +655,9 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         assert b0_reqs_end < num_reqs and \
             b0_tokens_end < total_num_scheduled_tokens
         ubatch_slices = [
-            (slice(0, b0_reqs_end), slice(0, b0_tokens_end)),
-            (slice(b0_reqs_end,
-                   num_reqs), slice(b0_tokens_end,
-                                    total_num_scheduled_tokens)),
+            UbatchSlice(slice(0, b0_reqs_end), slice(0, b0_tokens_end)),
+            UbatchSlice(slice(b0_reqs_end, num_reqs),
+                        slice(b0_tokens_end, total_num_scheduled_tokens)),
         ]
 
         # Compute ubatch padding. This currently only accounts for DP padding
@@ -1595,10 +1593,10 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         first_ubatch_slice = ubatch_slices[0]
         second_ubatch_slice = ubatch_slices[1]
 
-        first_ubatch_num_tokens = first_ubatch_slice[
-            1].stop - first_ubatch_slice[1].start
-        second_ubatch_num_tokens = second_ubatch_slice[
-            1].stop - second_ubatch_slice[1].start
+        first_ubatch_num_tokens = first_ubatch_slice.token_slice.stop - \
+            first_ubatch_slice.token_slice.start
+        second_ubatch_num_tokens = second_ubatch_slice.token_slice.stop - \
+            second_ubatch_slice.token_slice.start
         # We don't support prefills yet so the two ubatches should only differ
         # by at most one token
         assert abs(first_ubatch_num_tokens - second_ubatch_num_tokens) <= 1
@@ -1635,7 +1633,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
     # slicing but before attention meta data creation
     def pad_out_ubatch_first_stage(self, ubatch_slices: UBatchSlices,
                                    num_pad_tokens: int):
-        original_num_tokens = ubatch_slices[1][1].stop
+        original_num_tokens = ubatch_slices[1].token_slice.stop
         assert num_pad_tokens < original_num_tokens
         total_num_tokens_per_ubatch = (original_num_tokens +
                                        num_pad_tokens) // 2
@@ -1643,10 +1641,10 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         padded_second_ubatch_slice = slice(total_num_tokens_per_ubatch,
                                            original_num_tokens)
 
-        ubatch_slices[0] = (padded_first_ubatch_slice,
-                            padded_first_ubatch_slice)
-        ubatch_slices[1] = (padded_second_ubatch_slice,
-                            padded_second_ubatch_slice)
+        ubatch_slices[0] = UbatchSlice(padded_first_ubatch_slice,
+                                       padded_first_ubatch_slice)
+        ubatch_slices[1] = UbatchSlice(padded_second_ubatch_slice,
+                                       padded_second_ubatch_slice)
 
     # This is where the second ubatch is adjusted to account for the padding.
     # Should be called after attention metadata creation. This just pads
@@ -1655,10 +1653,10 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
     def pad_out_ubatch_second_stage(self, ubatch_slices: UBatchSlices,
                                     num_total_tokens: int):
         # TODO Add asserts to make sure stage one ran
-        padded_second_ubatch_slice = slice(ubatch_slices[1][1].start,
+        padded_second_ubatch_slice = slice(ubatch_slices[1].token_slice.start,
                                            num_total_tokens)
-        ubatch_slices[1] = (padded_second_ubatch_slice,
-                            padded_second_ubatch_slice)
+        ubatch_slices[1] = UbatchSlice(padded_second_ubatch_slice,
+                                       padded_second_ubatch_slice)
 
     def should_ubatch(self, should_ubatch: bool) -> bool:
         dp_size = self.vllm_config.parallel_config.data_parallel_size
@@ -1753,8 +1751,9 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
 
         # Create one forward context per ubatch
         forward_contexts = []
-        for i, (_, tokens_slice) in enumerate(ubatch_slices):
-            num_tokens = (tokens_slice.stop - tokens_slice.start)
+        for i, ubatch_slice in enumerate(ubatch_slices):
+            num_tokens = (ubatch_slice.token_slice.stop -
+                          ubatch_slice.token_slice.start)
             forward_contexts.append(
                 create_forward_context(
                     attn_metadata[i] if attn_metadata is not None else None,
@@ -1772,17 +1771,18 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
             enable_async_comms=self.parallel_config.enable_async_comms)
 
         ubatch_metadata: list[UbatchMetadata] = []
-        for i, (_, tokens_slice) in enumerate(ubatch_slices):
+        for i, ubatch_slice in enumerate(ubatch_slices):
             input_ids, positions, inputs_embeds, intermediate_tensors = \
-                self.model_inputs(tokens_slice, scheduler_output, is_dummy_run)
+                self.model_inputs(
+                    ubatch_slice.token_slice, scheduler_output, is_dummy_run)
             ubatch_metadata.append(
                 UbatchMetadata(context=ubatch_ctxs[i],
                                input_ids=input_ids,
                                positions=positions,
                                inputs_embeds=inputs_embeds,
                                intermediate_tensors=intermediate_tensors,
-                               num_tokens=tokens_slice.stop -
-                               tokens_slice.start))
+                               num_tokens=ubatch_slice.token_slice.stop -
+                               ubatch_slice.token_slice.start))
 
         return ubatch_metadata
 
@@ -1808,8 +1808,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         results: list[tuple[int, torch.Tensor]] = []
         compute_stream = ubatch_metadata[0].context.compute_stream
-        num_tokens = ubatch_metadata[0].num_tokens + ubatch_metadata[
-            1].num_tokens
+        num_tokens = ubatch_metadata[0].num_tokens + \
+            ubatch_metadata[1].num_tokens
 
         # Ubatches will manually manage the forward context, so we override
         # it to None here so we can have it restored correctly later
@@ -2704,10 +2704,12 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
                                       dp_size,
                                       device="cpu",
                                       dtype=torch.int32)
-        ubatch_slices = [(slice(0,
-                                num_reqs // 2), slice(0, num_tokens // 2)),
-                         (slice(num_reqs // 2, num_reqs),
-                          slice(num_tokens // 2, num_tokens))]
+        ubatch_slices = [
+            UbatchSlice(slice(0, num_reqs // 2), slice(0,
+                                                       num_tokens // 2)),
+            UbatchSlice(slice(num_reqs // 2, num_reqs),
+                        slice(num_tokens // 2, num_tokens))
+        ]
 
         # attn_metadata: Optional[dict[str, Any]] = None
         attn_metadata: Optional[PerLayerAttnMetadata] = None
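
For orientation, here is a minimal sketch of the shape of UbatchSlice that this refactor relies on, based only on the accesses visible in the diff (a request_slice and a token_slice field, with per-ubatch token counts computed as token_slice.stop - token_slice.start). The real definition lives in vllm/v1/attention/backends/utils.py; the num_tokens helper below is illustrative only and not part of the vLLM API.

from dataclasses import dataclass


@dataclass
class UbatchSlice:
    # Named fields replace the old (request_slice, token_slice) tuple.
    request_slice: slice
    token_slice: slice

    def num_tokens(self) -> int:
        # Illustrative helper; the model runner computes this inline as
        # token_slice.stop - token_slice.start.
        return self.token_slice.stop - self.token_slice.start


# Splitting a pure-decode batch of 6 requests / 6 tokens into two even
# microbatches, mirroring test_split_attn_metadata_decode_batch above.
num_reqs = num_tokens = 6
mid_point = num_reqs // 2
ubatch_slices = [
    UbatchSlice(slice(0, mid_point), slice(0, mid_point)),
    UbatchSlice(slice(mid_point, num_reqs), slice(mid_point, num_tokens)),
]
assert [s.num_tokens() for s in ubatch_slices] == [3, 3]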