diff --git a/tests/v1/attention/test_attention_splitting.py b/tests/v1/attention/test_attention_splitting.py deleted file mode 100644 index 3fc1011d5042e..0000000000000 --- a/tests/v1/attention/test_attention_splitting.py +++ /dev/null @@ -1,157 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -import pytest -import torch - -from tests.v1.attention.test_attention_backends import BATCH_SPECS -from tests.v1.attention.utils import create_common_attn_metadata -from vllm.v1.attention.backends.utils import (UbatchSlice, - _make_metadata_with_slice, - slice_query_start_locs, - split_attn_metadata) - - -@pytest.fixture -def sample_query_start_loc(): - """Sample query_start_loc tensor for testing""" - return torch.tensor([0, 5, 12, 20, 35, 50]) - - -def test_basic_slice_middle(sample_query_start_loc): - """Test slicing from middle of tensor""" - req_slice = slice(1, 3) # slice from index 1 to 3 - result = slice_query_start_locs(sample_query_start_loc, req_slice) - - expected = torch.tensor([0, 7, 15]) - assert torch.equal(result, expected) - - -def test_slice_from_beginning(sample_query_start_loc): - """Test slicing from the beginning of tensor""" - req_slice = slice(0, 2) # slice from index 0 to 2 - result = slice_query_start_locs(sample_query_start_loc, req_slice) - - expected = torch.tensor([0, 5, 12]) - assert torch.equal(result, expected) - - -def test_slice_to_end(sample_query_start_loc): - """Test slicing to the end of tensor""" - req_slice = slice(3, 5) # slice from index 3 to 5 (last index) - result = slice_query_start_locs(sample_query_start_loc, req_slice) - - expected = torch.tensor([0, 15, 30]) - assert torch.equal(result, expected) - - -def test_single_element_slice(sample_query_start_loc): - """Test slice that results in single element""" - req_slice = slice(2, 3) # slice from index 2 to 3 - result = slice_query_start_locs(sample_query_start_loc, req_slice) - - expected = torch.tensor([0, 8]) - assert torch.equal(result, expected) - - -def test_full_tensor_slice(sample_query_start_loc): - """Test slicing the entire tensor""" - req_slice = slice(0, 5) # slice entire tensor - result = slice_query_start_locs(sample_query_start_loc, req_slice) - - expected = torch.tensor([0, 5, 12, 20, 35, 50]) - assert torch.equal(result, expected) - - -def test_slice_bounds_edge_cases(sample_query_start_loc): - # Test slice that goes exactly to the last element - req_slice = slice(4, 5) # Last index - result = slice_query_start_locs(sample_query_start_loc, req_slice) - - expected = torch.tensor([0, 15]) - assert torch.equal(result, expected) - - -@pytest.fixture -def small_decode_metadata(): - """Create metadata for small decode batch""" - batch_spec = BATCH_SPECS["small_decode"] - device = torch.device("cpu") - return create_common_attn_metadata(batch_spec, - block_size=16, - device=device) - - -@pytest.fixture -def large_decode_metadata(): - """Create metadata for small decode batch""" - batch_spec = BATCH_SPECS["large_decode"] - device = torch.device("cpu") - return create_common_attn_metadata(batch_spec, - block_size=16, - device=device) - - -@pytest.fixture -def mixed_small_metadata(): - """Create metadata for mixed small batch""" - batch_spec = BATCH_SPECS["mixed_small"] - device = torch.device("cpu") - return create_common_attn_metadata(batch_spec, - block_size=16, - device=device) - - -# Tests for _make_metadata_with_slice -def test_make_metadata_with_slice_decode_batch(small_decode_metadata): - """Test slicing decode batch metadata""" - # Split first request only - ubatch_slice = UbatchSlice(slice(0, 1), slice(0, 1)) - - result = _make_metadata_with_slice(ubatch_slice, small_decode_metadata) - - # Check sliced results - assert result.num_reqs == 1 # slice(0, 1) gives 1 requests - assert result.num_actual_tokens == 1 # slice(0, 1) gives 1 token - assert result.max_query_len == 1 - assert torch.equal(result.query_start_loc, torch.tensor([0, 1])) - assert torch.equal(result.seq_lens, torch.tensor([32])) - - -def test_make_metadata_with_slice_mixed_batch(mixed_small_metadata): - """Test slicing mixed batch metadata""" - ubatch_slice = UbatchSlice(slice(1, 3), - slice(1, 7)) # Requests 1-3, tokens 1-7 - - result = _make_metadata_with_slice(ubatch_slice, mixed_small_metadata) - - assert result.num_reqs == 2 # slice(1, 3) gives 2 requests - assert result.num_actual_tokens == 6 # slice(1, 7) gives 6 tokens - assert result.max_query_len == 5 - assert torch.equal(result.query_start_loc, torch.tensor([0, 1, 6])) - assert torch.equal(result.seq_lens, torch.tensor([40, 48])) - - -def test_split_attn_metadata_decode_batch(large_decode_metadata): - """Test splitting decode batch into two equal parts""" - num_tokens = large_decode_metadata.num_reqs - mid_point = num_tokens // 2 - ubatch_slices = [ - UbatchSlice(slice(0, mid_point), slice(0, mid_point)), - UbatchSlice(slice(mid_point, num_tokens), slice(mid_point, - num_tokens)), - ] - - results = split_attn_metadata(ubatch_slices, large_decode_metadata) - - assert len(results) == 2 - - # Check first split - assert results[0].num_reqs == mid_point - assert results[0].num_actual_tokens == mid_point - assert torch.equal(results[0].seq_lens, torch.tensor([2048] * mid_point)) - - # Check second split - assert results[1].num_reqs == mid_point - assert results[1].num_actual_tokens == mid_point - assert torch.equal(results[1].seq_lens, torch.tensor([2048] * mid_point)) diff --git a/vllm/v1/attention/backends/utils.py b/vllm/v1/attention/backends/utils.py index 2fb768c9d710d..68c51340a2ac4 100644 --- a/vllm/v1/attention/backends/utils.py +++ b/vllm/v1/attention/backends/utils.py @@ -76,7 +76,6 @@ def slice_query_start_locs( """ Creates a new query_start_loc that corresponds to the requests in request_slice. - Note: This function creates a new tensor to hold the new query_start_locs. This will break cudagraph compatibility. """ @@ -130,19 +129,19 @@ def _make_metadata_with_slice( def split_attn_metadata( - ubatch_slices: list[UbatchSlice], + ubatch_slices: list[tuple[slice, slice]], common_attn_metadata: CommonAttentionMetadata, ) -> list[CommonAttentionMetadata]: """ Creates a new CommonAttentionMetadata instance that corresponds to the requests for each UbatchSlice in ubatch_slices. - Note: This function does not modify common_attn_metadata """ results = [] for ubatch_slice in ubatch_slices: - results.append( - _make_metadata_with_slice(ubatch_slice, common_attn_metadata)) + s = UbatchSlice(request_slice=ubatch_slice[0], + token_slice=ubatch_slice[1]) + results.append(_make_metadata_with_slice(s, common_attn_metadata)) return results diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index dcf8cf158e307..19a8e161f29f0 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -52,7 +52,7 @@ from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadata from vllm.v1.attention.backends.mamba_selectors import get_mamba_attn_backend from vllm.v1.attention.backends.utils import ( AttentionCGSupport, AttentionMetadataBuilder, CommonAttentionMetadata, - UbatchSlice, make_kv_sharing_fast_prefill_attention_metadata, + make_kv_sharing_fast_prefill_attention_metadata, make_local_attention_virtual_batches, split_attn_metadata) from vllm.v1.core.encoder_cache_manager import compute_encoder_budget from vllm.v1.kv_cache_interface import (AttentionSpec, @@ -100,6 +100,7 @@ AttnMetadataDict: TypeAlias = dict[str, FlashAttentionMetadata] PerLayerAttnMetadata: TypeAlias = Union[list[AttnMetadataDict], AttnMetadataDict] +UbatchSlice: TypeAlias = tuple[slice, slice] UBatchSlices: TypeAlias = list[UbatchSlice] @@ -655,9 +656,10 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): assert b0_reqs_end < num_reqs and \ b0_tokens_end < total_num_scheduled_tokens ubatch_slices = [ - UbatchSlice(slice(0, b0_reqs_end), slice(0, b0_tokens_end)), - UbatchSlice(slice(b0_reqs_end, num_reqs), - slice(b0_tokens_end, total_num_scheduled_tokens)), + (slice(0, b0_reqs_end), slice(0, b0_tokens_end)), + (slice(b0_reqs_end, + num_reqs), slice(b0_tokens_end, + total_num_scheduled_tokens)), ] # Compute ubatch padding. This currently only accounts for DP padding @@ -1593,10 +1595,10 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): first_ubatch_slice = ubatch_slices[0] second_ubatch_slice = ubatch_slices[1] - first_ubatch_num_tokens = first_ubatch_slice.token_slice.stop - \ - first_ubatch_slice.token_slice.start - second_ubatch_num_tokens = second_ubatch_slice.token_slice.stop - \ - second_ubatch_slice.token_slice.start + first_ubatch_num_tokens = first_ubatch_slice[ + 1].stop - first_ubatch_slice[1].start + second_ubatch_num_tokens = second_ubatch_slice[ + 1].stop - second_ubatch_slice[1].start # We don't support prefills yet so the two ubatches should only differ # by at most one token assert abs(first_ubatch_num_tokens - second_ubatch_num_tokens) <= 1 @@ -1633,7 +1635,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): # slicing but before attention meta data creation def pad_out_ubatch_first_stage(self, ubatch_slices: UBatchSlices, num_pad_tokens: int): - original_num_tokens = ubatch_slices[1].token_slice.stop + original_num_tokens = ubatch_slices[1][1].stop assert num_pad_tokens < original_num_tokens total_num_tokens_per_ubatch = (original_num_tokens + num_pad_tokens) // 2 @@ -1641,10 +1643,10 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): padded_second_ubatch_slice = slice(total_num_tokens_per_ubatch, original_num_tokens) - ubatch_slices[0] = UbatchSlice(padded_first_ubatch_slice, - padded_first_ubatch_slice) - ubatch_slices[1] = UbatchSlice(padded_second_ubatch_slice, - padded_second_ubatch_slice) + ubatch_slices[0] = (padded_first_ubatch_slice, + padded_first_ubatch_slice) + ubatch_slices[1] = (padded_second_ubatch_slice, + padded_second_ubatch_slice) # This is where the second ubatch is adjusted to account for the padding. # Should be called after attention metadata creation. This just pads @@ -1653,10 +1655,10 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): def pad_out_ubatch_second_stage(self, ubatch_slices: UBatchSlices, num_total_tokens: int): # TODO Add asserts to make sure stage one ran - padded_second_ubatch_slice = slice(ubatch_slices[1].token_slice.start, + padded_second_ubatch_slice = slice(ubatch_slices[1][1].start, num_total_tokens) - ubatch_slices[1] = UbatchSlice(padded_second_ubatch_slice, - padded_second_ubatch_slice) + ubatch_slices[1] = (padded_second_ubatch_slice, + padded_second_ubatch_slice) def should_ubatch(self, should_ubatch: bool) -> bool: dp_size = self.vllm_config.parallel_config.data_parallel_size @@ -1751,9 +1753,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): # Create one forward context per ubatch forward_contexts = [] - for i, ubatch_slice in enumerate(ubatch_slices): - num_tokens = (ubatch_slice.token_slice.stop - - ubatch_slice.token_slice.start) + for i, (_, tokens_slice) in enumerate(ubatch_slices): + num_tokens = (tokens_slice.stop - tokens_slice.start) forward_contexts.append( create_forward_context( attn_metadata[i] if attn_metadata is not None else None, @@ -1771,18 +1772,17 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): enable_async_comms=self.parallel_config.enable_async_comms) ubatch_metadata: list[UbatchMetadata] = [] - for i, ubatch_slice in enumerate(ubatch_slices): + for i, (_, tokens_slice) in enumerate(ubatch_slices): input_ids, positions, inputs_embeds, intermediate_tensors = \ - self.model_inputs( - ubatch_slice.token_slice, scheduler_output, is_dummy_run) + self.model_inputs(tokens_slice, scheduler_output, is_dummy_run) ubatch_metadata.append( UbatchMetadata(context=ubatch_ctxs[i], input_ids=input_ids, positions=positions, inputs_embeds=inputs_embeds, intermediate_tensors=intermediate_tensors, - num_tokens=ubatch_slice.token_slice.stop - - ubatch_slice.token_slice.start)) + num_tokens=tokens_slice.stop - + tokens_slice.start)) return ubatch_metadata @@ -1808,8 +1808,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): results: list[tuple[int, torch.Tensor]] = [] compute_stream = ubatch_metadata[0].context.compute_stream - num_tokens = ubatch_metadata[0].num_tokens + \ - ubatch_metadata[1].num_tokens + num_tokens = ubatch_metadata[0].num_tokens + ubatch_metadata[ + 1].num_tokens # Ubatches will manually manage the forward context, so we override # it to None here so we can have it restored correctly later @@ -2704,12 +2704,10 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): dp_size, device="cpu", dtype=torch.int32) - ubatch_slices = [ - UbatchSlice(slice(0, num_reqs // 2), slice(0, - num_tokens // 2)), - UbatchSlice(slice(num_reqs // 2, num_reqs), - slice(num_tokens // 2, num_tokens)) - ] + ubatch_slices = [(slice(0, + num_reqs // 2), slice(0, num_tokens // 2)), + (slice(num_reqs // 2, num_reqs), + slice(num_tokens // 2, num_tokens))] # attn_metadata: Optional[dict[str, Any]] = None attn_metadata: Optional[PerLayerAttnMetadata] = None