diff --git a/tests/v1/core/test_kv_cache_utils.py b/tests/v1/core/test_kv_cache_utils.py index 1cdc80dd3546c..e572100fe7a18 100644 --- a/tests/v1/core/test_kv_cache_utils.py +++ b/tests/v1/core/test_kv_cache_utils.py @@ -19,7 +19,8 @@ from vllm.v1.core.kv_cache_utils import (NONE_HASH, BlockHashType, hash_request_tokens, unify_kv_cache_configs) from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig, - KVCacheGroupSpec, KVCacheTensor) + KVCacheGroupSpec, KVCacheTensor, + SlidingWindowSpec) from vllm.v1.metrics.stats import PrefixCacheStats from vllm.v1.request import Request @@ -54,12 +55,14 @@ def new_kv_cache_spec(block_size=16, num_kv_heads=2, head_size=64, dtype=torch.float32, - use_mla=False): + use_mla=False, + sliding_window=None): return FullAttentionSpec(block_size=block_size, num_kv_heads=num_kv_heads, head_size=head_size, dtype=dtype, - use_mla=use_mla) + use_mla=use_mla, + sliding_window=sliding_window) def test_none_hash(): @@ -471,6 +474,68 @@ def test_unify_kv_cache_configs(): unify_kv_cache_configs(diff_kv_cache_config) +def test_merge_kv_cache_spec(): + same_layer_specs = [ + new_kv_cache_spec(num_kv_heads=32), + new_kv_cache_spec(num_kv_heads=32), + ] + merged_layer_spec = same_layer_specs[0].merge(same_layer_specs) + assert merged_layer_spec.block_size == 16 + assert merged_layer_spec.num_kv_heads == 32 + assert merged_layer_spec.head_size == 64 + assert merged_layer_spec.dtype == torch.float32 + assert merged_layer_spec.sliding_window is None + + different_layer_specs = [ + new_kv_cache_spec(num_kv_heads=32), + new_kv_cache_spec(num_kv_heads=16), + ] + with pytest.raises(AssertionError): + different_layer_specs[0].merge(different_layer_specs) + + full_spec = new_kv_cache_spec(num_kv_heads=32) + different_type_layer_specs = [ + full_spec, + SlidingWindowSpec( + block_size=full_spec.block_size, + num_kv_heads=full_spec.num_kv_heads, + head_size=full_spec.head_size, + dtype=full_spec.dtype, + use_mla=full_spec.use_mla, + sliding_window=1, + ), + ] + with pytest.raises(AssertionError): + different_type_layer_specs[0].merge(different_type_layer_specs) + with pytest.raises(AssertionError): + different_type_layer_specs[1].merge(different_type_layer_specs) + + different_sliding_window_layer_specs = [ + new_kv_cache_spec(num_kv_heads=32), + new_kv_cache_spec(num_kv_heads=32, sliding_window=1), + new_kv_cache_spec(num_kv_heads=32, sliding_window=2), + ] + with pytest.raises(ValueError): + different_sliding_window_layer_specs[0].merge( + different_sliding_window_layer_specs) + + same_sliding_window_layer_specs = [ + new_kv_cache_spec(num_kv_heads=32, sliding_window=1), + new_kv_cache_spec(num_kv_heads=32, sliding_window=1), + ] + merged_layer_spec = same_sliding_window_layer_specs[0].merge( + same_sliding_window_layer_specs) + assert merged_layer_spec.sliding_window == 1 + + same_sliding_window_layer_spec_with_none = [ + new_kv_cache_spec(num_kv_heads=32, sliding_window=1), + new_kv_cache_spec(num_kv_heads=32, sliding_window=None), + ] + merged_layer_spec = same_sliding_window_layer_spec_with_none[0].merge( + same_sliding_window_layer_spec_with_none) + assert merged_layer_spec.sliding_window == 1 + + @pytest.mark.parametrize( ("model_id", "max_model_len", "want_estimated_max_len"), [ ("Qwen/Qwen1.5-7B", 16385, 16384), diff --git a/tests/v1/core/test_prefix_caching.py b/tests/v1/core/test_prefix_caching.py index 2d7411381e160..3da27786b1f2f 100644 --- a/tests/v1/core/test_prefix_caching.py +++ b/tests/v1/core/test_prefix_caching.py @@ -84,7 +84,7 @@ def 
test_prefill(hash_algo): blocks = manager.allocate_slots(req0, 55, len(computed_blocks.blocks) * 16, computed_blocks) - assert blocks.get_block_ids() == [1, 2, 3, 4] + assert blocks.get_block_ids() == [[1, 2, 3, 4]] # Check full block metadata parent_block_hash = None @@ -107,13 +107,13 @@ def test_prefill(hash_algo): req1 = make_request("1", common_token_ids + unique_token_ids) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1) assert len(manager.req_to_block_hashes[req1.request_id]) == 3 - assert computed_blocks.get_block_ids() == [1, 2, 3] + assert computed_blocks.get_block_ids() == [[1, 2, 3]] assert num_computed_tokens == 3 * 16 num_new_tokens = 53 - 3 * 16 blocks = manager.allocate_slots(req1, num_new_tokens, len(computed_blocks.blocks) * 16, computed_blocks) - assert blocks.get_block_ids() == [5] + assert blocks.get_block_ids() == [[5]] for block in computed_blocks.blocks: assert block.ref_cnt == 2 @@ -141,13 +141,13 @@ def test_prefill(hash_algo): req2 = make_request("2", common_token_ids + unique_token_ids) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req2) assert len(manager.req_to_block_hashes[req2.request_id]) == 3 - assert computed_blocks.get_block_ids() == [1, 2, 3] + assert computed_blocks.get_block_ids() == [[1, 2, 3]] assert num_computed_tokens == 3 * 16 num_new_tokens = 53 - 3 * 16 blocks = manager.allocate_slots(req2, num_new_tokens, len(computed_blocks.blocks) * 16, computed_blocks) - assert blocks.get_block_ids() == [6] + assert blocks.get_block_ids() == [[6]] # Although we only have 6 free blocks, we have 8 blocks in # the free block queue due to lazy removal. @@ -171,7 +171,7 @@ def test_prefill(hash_algo): len(computed_blocks.blocks) * 16, computed_blocks) # This block ID order also checks the eviction order. - assert blocks.get_block_ids() == [7, 8, 9, 10, 4, 5, 6, 3, 2, 1] + assert blocks.get_block_ids() == [[7, 8, 9, 10, 4, 5, 6, 3, 2, 1]] assert manager.block_pool.free_block_queue.num_free_blocks == 0 assert manager.block_pool.free_block_queue.free_list_head is None assert manager.block_pool.free_block_queue.free_list_tail is None @@ -208,7 +208,7 @@ def test_prefill_plp(): blocks = manager.allocate_slots(req0, 55, len(computed_blocks.blocks) * 16, computed_blocks) - assert blocks.get_block_ids() == [1, 2, 3, 4] + assert blocks.get_block_ids() == [[1, 2, 3, 4]] req0_block_hashes = [b.block_hash for b in blocks.blocks] # Check full block metadata @@ -233,13 +233,13 @@ def test_prefill_plp(): req1 = make_request("1", common_token_ids + unique_token_ids) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1) assert len(manager.req_to_block_hashes[req1.request_id]) == 3 - assert computed_blocks.get_block_ids() == [1, 2, 3] + assert computed_blocks.get_block_ids() == [[1, 2, 3]] assert num_computed_tokens == 3 * 16 num_new_tokens = 53 - 3 * 16 blocks = manager.allocate_slots(req1, num_new_tokens, len(computed_blocks.blocks) * 16, computed_blocks) - assert blocks.get_block_ids() == [5] + assert blocks.get_block_ids() == [[5]] for block in computed_blocks.blocks: assert block.ref_cnt == 2 @@ -277,11 +277,11 @@ def test_prefill_plp(): block_ids = blocks.get_block_ids() # Duplicate cached blocks have different ids but same hashes vs request #0 assert [b.block_hash for b in blocks.blocks] == req0_block_hashes - assert block_ids != [1, 2, 3, 4] + assert block_ids != [[1, 2, 3, 4]] # Request #2 block hashes are valid since request #0 hashes are. # Check block reference counts. 
- for block_id in block_ids: + for block_id in block_ids[0]: assert manager.block_pool.blocks[block_id].ref_cnt == 1 manager.free(req2) @@ -307,7 +307,7 @@ def test_decode(): blocks = manager.allocate_slots(req0, 55, len(computed_blocks.blocks) * 16, computed_blocks) - assert blocks.get_block_ids() == [1, 2, 3, 4] + assert blocks.get_block_ids() == [[1, 2, 3, 4]] # Append slots without allocating a new block. req0.num_computed_tokens = 55 @@ -379,12 +379,12 @@ def test_evict(): # Touch the first 2 blocks. req2 = make_request("2", list(range(2 * 16 + 3))) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req2) - assert computed_blocks.get_block_ids() == [1, 2] + assert computed_blocks.get_block_ids() == [[1, 2]] assert num_computed_tokens == 2 * 16 blocks = manager.allocate_slots(req2, 3, len(computed_blocks.blocks) * 16, computed_blocks) - assert blocks.get_block_ids() == [10] + assert blocks.get_block_ids() == [[10]] assert manager.block_pool.free_block_queue.num_free_blocks == 7 @@ -625,7 +625,7 @@ def test_mm_prefix_caching(): blocks = manager.allocate_slots(req0, 59, len(computed_blocks.blocks) * 16, computed_blocks) - assert blocks.get_block_ids() == [1, 2, 3, 4] + assert blocks.get_block_ids() == [[1, 2, 3, 4]] req0.num_computed_tokens = 59 # Append slots without allocating a new block. @@ -686,7 +686,7 @@ def test_cache_key_salting(): blocks = manager.allocate_slots(req0, 59, len(computed_blocks.blocks) * 16, computed_blocks) - assert blocks.get_block_ids() == [1, 2, 3, 4] + assert blocks.get_block_ids() == [[1, 2, 3, 4]] req0.num_computed_tokens = 59 # Append slots without allocating a new block. @@ -797,7 +797,7 @@ def test_reset_prefix_cache(): all_token_ids = full_block_token_ids + unique_token_ids req0 = make_request("0", all_token_ids) blocks = manager.allocate_slots(req0, 55) - assert blocks.get_block_ids() == [1, 2, 3, 4] + assert blocks.get_block_ids() == [[1, 2, 3, 4]] unique_token_ids = [4] * 7 all_token_ids = full_block_token_ids + unique_token_ids @@ -808,7 +808,7 @@ def test_reset_prefix_cache(): blocks = manager.allocate_slots(req1, 7, len(computed_blocks.blocks) * 16, computed_blocks) - assert blocks.get_block_ids() == [5] + assert blocks.get_block_ids() == [[5]] # Failed to reset prefix cache because some blocks are not freed yet. 
assert not manager.reset_prefix_cache() diff --git a/tests/v1/worker/test_gpu_input_batch.py b/tests/v1/worker/test_gpu_input_batch.py index 7b1359c8576f7..638f5bedcfcac 100644 --- a/tests/v1/worker/test_gpu_input_batch.py +++ b/tests/v1/worker/test_gpu_input_batch.py @@ -9,9 +9,11 @@ import torch from vllm.sampling_params import SamplingParams from vllm.utils import is_pin_memory_available, make_tensor_with_pad +from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig, + KVCacheGroupSpec, KVCacheTensor) from vllm.v1.sample.metadata import SamplingMetadata -from vllm.v1.worker.gpu_input_batch import (BlockTable, CachedRequestState, - InputBatch) +from vllm.v1.worker.block_table import BlockTable, MultiGroupBlockTable +from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch VOCAB_SIZE = 1024 NUM_OUTPUT_TOKENS = 20 @@ -22,6 +24,27 @@ CUDA_DEVICES = [ MAX_NUM_PROMPT_TOKENS = 64 +def get_kv_cache_config() -> KVCacheConfig: + return KVCacheConfig( + num_blocks=10, + tensors={ + "layer.0": KVCacheTensor(size=1024), + }, + kv_cache_groups=[ + KVCacheGroupSpec( + layer_names=["layer.0"], + kv_cache_spec=FullAttentionSpec( + block_size=1, + num_kv_heads=1, + head_size=16, + dtype=torch.float16, + use_mla=False, + ), + ), + ], + ) + + def _compare_objs(obj1, obj2): attrs = inspect.getmembers(obj1, lambda a: not (inspect.isroutine(a))) attr_names = set([ @@ -41,6 +64,10 @@ def _compare_objs(obj1, obj2): elif isinstance(a, np.ndarray): if np.allclose(a, b): is_same = True + elif isinstance(a, MultiGroupBlockTable): + for a_i, b_i in zip(a.block_tables, b.block_tables): + _compare_objs(a_i, b_i) + is_same = True elif isinstance(a, (BlockTable, SamplingMetadata)): _compare_objs(a, b) is_same = True # if we make it here must be same @@ -198,7 +225,7 @@ def _construct_cached_request_state(req_id_suffix: int): sampling_params=_create_sampling_params(), mm_inputs=[], mm_positions=[], - block_ids=[], + block_ids=[[]], generator=None, num_computed_tokens=len(output_token_ids), output_token_ids=output_token_ids, @@ -220,11 +247,11 @@ def test_sampling_metadata_in_input_batch(device: str, batch_size: int): input_batch: InputBatch = InputBatch( max_num_reqs=batch_size, max_model_len=1024, - max_num_blocks_per_req=10, max_num_batched_tokens=1024, device=torch.device(device), pin_memory=is_pin_memory_available(), vocab_size=1024, + kv_cache_config=get_kv_cache_config(), ) reqs: list[CachedRequestState] = [] req_id_reqs = {} @@ -310,20 +337,20 @@ def test_swap_states_in_input_batch(device: str, batch_size: int, input_batch: InputBatch = InputBatch( max_num_reqs=batch_size, max_model_len=1024, - max_num_blocks_per_req=10, max_num_batched_tokens=1024, device=torch.device(device), pin_memory=is_pin_memory_available(), vocab_size=1024, + kv_cache_config=get_kv_cache_config(), ) ref_input_batch: InputBatch = InputBatch( max_num_reqs=batch_size, max_model_len=1024, - max_num_blocks_per_req=10, max_num_batched_tokens=1024, device=torch.device(device), pin_memory=is_pin_memory_available(), vocab_size=1024, + kv_cache_config=get_kv_cache_config(), ) reqs: list[CachedRequestState] = [] diff --git a/tests/v1/worker/test_gpu_model_runner.py b/tests/v1/worker/test_gpu_model_runner.py index 725747294fd8d..e44660525763c 100644 --- a/tests/v1/worker/test_gpu_model_runner.py +++ b/tests/v1/worker/test_gpu_model_runner.py @@ -1,15 +1,16 @@ # SPDX-License-Identifier: Apache-2.0 -import weakref import pytest -import torch -from vllm.config import CacheConfig, ModelConfig, SchedulerConfig, VllmConfig 
+from vllm.config import (CacheConfig, ModelConfig, ParallelConfig, + SchedulerConfig, VllmConfig) from vllm.sampling_params import SamplingParams from vllm.v1.core.sched.output import (CachedRequestData, NewRequestData, SchedulerOutput) -from vllm.v1.kv_cache_interface import FullAttentionSpec +from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig, + KVCacheGroupSpec, KVCacheTensor) from vllm.v1.sample.metadata import SamplingMetadata +from vllm.v1.worker.gpu_input_batch import InputBatch from vllm.v1.worker.gpu_model_runner import GPUModelRunner @@ -17,13 +18,34 @@ def initialize_kv_cache(runner: GPUModelRunner): """ Only perform necessary steps in GPUModelRunner.initialize_kv_cache() """ - kv_cache_spec = FullAttentionSpec(block_size=16, - num_kv_heads=1, - head_size=64, - dtype=torch.float16, - use_mla=False) - runner.attn_metadata_builder = runner.attn_backend.get_builder_cls()( - weakref.proxy(runner), kv_cache_spec, runner.input_batch.block_table) + kv_cache_config = KVCacheConfig( + num_blocks=10, + tensors={ + "layer.0": KVCacheTensor(size=1024), + }, + kv_cache_groups=[ + KVCacheGroupSpec( + layer_names=["layer.0"], + kv_cache_spec=FullAttentionSpec( + block_size=16, + num_kv_heads=runner.model_config.get_num_kv_heads( + runner.parallel_config), + head_size=runner.model_config.get_head_size(), + dtype=runner.kv_cache_dtype, + use_mla=False, + )) + ]) + runner.kv_cache_config = kv_cache_config + runner.input_batch = InputBatch( + max_num_reqs=runner.max_num_reqs, + max_model_len=runner.max_model_len, + max_num_batched_tokens=runner.max_num_tokens, + device=runner.device, + pin_memory=runner.pin_memory, + vocab_size=runner.model_config.get_vocab_size(), + kv_cache_config=kv_cache_config, + ) + runner.initialize_attn_backend(kv_cache_config) @pytest.fixture @@ -48,10 +70,12 @@ def model_runner(): swap_space=0, cache_dtype="auto", ) + parallel_config = ParallelConfig() vllm_config = VllmConfig( model_config=model_config, cache_config=cache_config, scheduler_config=scheduler_config, + parallel_config=parallel_config, ) device = "cuda" @@ -73,7 +97,7 @@ def _schedule_new_request(*req_ids: str) -> SchedulerOutput: mm_hashes=[], mm_positions=[], sampling_params=SamplingParams(), - block_ids=[0], + block_ids=[[0]], num_computed_tokens=0, lora_request=None, )) @@ -111,13 +135,14 @@ def _is_sampling_metadata_changed(model_runner, def _is_req_state_block_table_match(model_runner, req_id: str) -> bool: req_index = model_runner.input_batch.req_id_to_index[req_id] - block_table = model_runner.input_batch.block_table + block_table = model_runner.input_batch.block_table[0] req_state = model_runner.requests[req_id] - if block_table.num_blocks_per_row[req_index] != len(req_state.block_ids): + if block_table.num_blocks_per_row[req_index] != len( + req_state.block_ids[0]): return False num_blocks = block_table.num_blocks_per_row[req_index] return (block_table.block_table_np[req_index, :num_blocks] == - req_state.block_ids).all() + req_state.block_ids[0]).all() def test_update_states_new_request(model_runner): @@ -200,7 +225,7 @@ def test_update_states_request_resumed(model_runner): req_id=req_id, resumed_from_preemption=False, new_token_ids=[], - new_block_ids=[], + new_block_ids=[[]], num_computed_tokens=0, ) diff --git a/tests/weight_loading/models.txt b/tests/weight_loading/models.txt index 1b797074096ed..9164f8595346e 100644 --- a/tests/weight_loading/models.txt +++ b/tests/weight_loading/models.txt @@ -2,7 +2,7 @@ gptq_marlin, robertgshaw2/zephyr-7b-beta-channelwise-gptq, main 
gptq_marlin, TheBloke/Llama-2-7B-GPTQ, main gptq_marlin, TheBloke/TinyLlama-1.1B-Chat-v1.0-GPTQ, main gptq_marlin, TheBloke/TinyLlama-1.1B-Chat-v1.0-GPTQ, gptq-8bit--1g-actorder_True -gptq_marlin, TheBloke/TinyLlama-1.1B-Chat-v1.0-GPTQ, gptq-8bit-32g-actorder_True +#gptq_marlin, TheBloke/TinyLlama-1.1B-Chat-v1.0-GPTQ, gptq-8bit-32g-actorder_True gptq_marlin, TechxGenus/gemma-1.1-2b-it-GPTQ, main gptq, robertgshaw2/zephyr-7b-beta-channelwise-gptq, main gptq, TheBloke/Llama-2-7B-GPTQ, main diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/shared_storage_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/shared_storage_connector.py index 0fedb6fd5ed92..0421a65a2c819 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/shared_storage_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/shared_storage_connector.py @@ -288,7 +288,7 @@ class SharedStorageConnector(KVConnectorBase_V1): for new_req in scheduler_output.scheduled_new_reqs: if new_req.req_id in self._requests_need_load: meta.add_request(token_ids=new_req.prompt_token_ids, - block_ids=new_req.block_ids, + block_ids=new_req.block_ids[0], block_size=self._block_size, is_store=False) total_need_load += 1 @@ -299,7 +299,7 @@ class SharedStorageConnector(KVConnectorBase_V1): # the original prompt tokens. if not self._found_match_for_request(new_req): meta.add_request(token_ids=new_req.prompt_token_ids, - block_ids=new_req.block_ids, + block_ids=new_req.block_ids[0], block_size=self._block_size, is_store=True) @@ -319,7 +319,7 @@ class SharedStorageConnector(KVConnectorBase_V1): # NOTE(rob): For resumed req, new_block_ids is all # of the block_ids for the request. - block_ids = cached_req.new_block_ids + block_ids = cached_req.new_block_ids[0] meta.add_request(token_ids=token_ids, block_ids=block_ids, diff --git a/vllm/v1/attention/backends/mla/rocm_aiter_mla.py b/vllm/v1/attention/backends/mla/rocm_aiter_mla.py index 3abb185c5b8fe..7ce39110ac01d 100644 --- a/vllm/v1/attention/backends/mla/rocm_aiter_mla.py +++ b/vllm/v1/attention/backends/mla/rocm_aiter_mla.py @@ -67,13 +67,13 @@ class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]): max_model_len = self.runner.model_config.max_model_len assert max_model_len == 32768,\ "AITER MLA requires max_model_len=32768" - assert self.runner.block_size == 1, "AITER MLA" \ + assert self.kv_cache_spec.block_size == 1, "AITER MLA" \ "only supports block size 1." def _get_paged_kv_tensors( self, block_table: torch.Tensor, seq_lens: torch.Tensor) -> tuple[torch.Tensor, ...]: - page_size = self.runner.block_size + page_size = self.kv_cache_spec.block_size block_table_bounds = (seq_lens + page_size - 1) // page_size mask = (torch.arange(block_table.size(1), diff --git a/vllm/v1/core/kv_cache_manager.py b/vllm/v1/core/kv_cache_manager.py index 598fc871110e3..da18ece7555a2 100644 --- a/vllm/v1/core/kv_cache_manager.py +++ b/vllm/v1/core/kv_cache_manager.py @@ -32,9 +32,16 @@ class KVCacheBlocks: """Creates a new KVCacheBlocks instance with no blocks.""" return cls([]) - def get_block_ids(self) -> list[int]: - """Converts the KVCacheBlocks instance to a list of block IDs.""" - return [block.block_id for block in self.blocks] + def get_block_ids(self) -> list[list[int]]: + """ + Converts the KVCacheBlocks instance to block_ids. 
+ + Returns: + list[list[int]]: A two-level list where + * the outer list corresponds to KV cache groups (only 1 group now) + * each inner list contains the block_ids of the blocks in that group + """ + return [[block.block_id for block in self.blocks]] def get_unhashed_block_ids(self) -> list[int]: """Get block_ids of unhashed blocks from KVCacheBlocks instance.""" @@ -300,9 +307,9 @@ class KVCacheManager: self, request: Request, num_running_requests: int, - ) -> int: + ) -> list[int]: """Calculate the number of common prefix blocks shared by all requests - in the RUNNING state. + in the RUNNING state for each kv cache group. The function determines this by selecting any request and iterating through its blocks. A block is considered a common prefix block if its @@ -332,11 +339,14 @@ class KVCacheManager: requests in the current step. Returns: - int: The number of common prefix blocks. + list[int]: The number of common prefix blocks for each kv cache + group. """ assert request.status == RequestStatus.RUNNING - return self.single_type_manager.get_num_common_prefix_blocks( - request.request_id, num_running_requests) + return [ + self.single_type_manager.get_num_common_prefix_blocks( + request.request_id, num_running_requests) + ] def free_block_hashes(self, request: Request) -> None: """Discard the block hashes for the request. @@ -354,10 +364,8 @@ class KVCacheManager: """ return self.block_pool.take_events() - def get_block_ids(self, request_id: str) -> list[int]: + def get_block_ids(self, request_id: str) -> list[list[int]]: """Get the block ids of a request.""" assert request_id in self.single_type_manager.req_to_blocks - return [ - block.block_id - for block in self.single_type_manager.req_to_blocks[request_id] - ] + return KVCacheBlocks(self.single_type_manager.req_to_blocks[request_id] + ).get_block_ids() diff --git a/vllm/v1/core/kv_cache_utils.py b/vllm/v1/core/kv_cache_utils.py index 27c5158350878..403b5401be75a 100644 --- a/vllm/v1/core/kv_cache_utils.py +++ b/vllm/v1/core/kv_cache_utils.py @@ -577,14 +577,12 @@ def create_kv_cache_group_specs( """ kv_cache_groups = [] for layer_names_one_group in grouped_layer_names: - layer_spec = kv_cache_spec[layer_names_one_group[0]] - assert all( - kv_cache_spec[layer_name] == layer_spec - for layer_name in layer_names_one_group[1:]), ( - "All layers in the same KV cache group must share the same " - "KVCacheSpec.") + layer_specs = [ + kv_cache_spec[layer_name] for layer_name in layer_names_one_group + ] + merged_layer_spec = layer_specs[0].merge(layer_specs) kv_cache_groups.append( - KVCacheGroupSpec(layer_names_one_group, layer_spec)) + KVCacheGroupSpec(layer_names_one_group, merged_layer_spec)) return kv_cache_groups @@ -683,6 +681,7 @@ def unify_hybrid_kv_cache_specs(kv_cache_spec: dict[str, KVCacheSpec]): head_size=spec.head_size, dtype=spec.dtype, use_mla=spec.use_mla, + sliding_window=spec.sliding_window, ) diff --git a/vllm/v1/core/sched/output.py b/vllm/v1/core/sched/output.py index 24032498e50ba..2572344309837 100644 --- a/vllm/v1/core/sched/output.py +++ b/vllm/v1/core/sched/output.py @@ -26,7 +26,7 @@ class NewRequestData: mm_hashes: list[str] mm_positions: list[PlaceholderRange] sampling_params: SamplingParams - block_ids: list[int] + block_ids: list[list[int]] num_computed_tokens: int lora_request: Optional[LoRARequest] @@ -34,7 +34,7 @@ class NewRequestData: def from_request( cls, request: Request, - block_ids: list[int], + block_ids: list[list[int]], ) -> NewRequestData: return cls( req_id=request.request_id, @@ -85,7 +85,7 
@@ class CachedRequestData: # request's block IDs instead of appending to the existing block IDs. resumed_from_preemption: bool new_token_ids: list[int] - new_block_ids: list[int] + new_block_ids: list[list[int]] num_computed_tokens: int @classmethod @@ -94,7 +94,7 @@ class CachedRequestData: request: Request, resumed_from_preemption: bool, new_token_ids: list[int], - new_block_ids: list[int], + new_block_ids: list[list[int]], ) -> CachedRequestData: return cls( req_id=request.request_id, @@ -131,9 +131,9 @@ class SchedulerOutput: # E.g., if a request has [0, 1], it could mean the vision encoder needs # to process that the request's 0-th and 1-th images in the current step. scheduled_encoder_inputs: dict[str, list[int]] - # Number of common prefix blocks for all requests. + # Number of common prefix blocks for all requests in each KV cache group. # This can be used for cascade attention. - num_common_prefix_blocks: int + num_common_prefix_blocks: list[int] # Request IDs that are finished in between the previous and the current # steps. This is used to notify the workers about the finished requests diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py index 96313c288f7d0..5ad05485e8f33 100644 --- a/vllm/v1/core/sched/scheduler.py +++ b/vllm/v1/core/sched/scheduler.py @@ -173,7 +173,7 @@ class Scheduler(SchedulerInterface): # uses structured decoding. structured_output_request_ids: dict[str, int] = {} - req_to_new_block_ids: dict[str, list[int]] = {} + req_to_new_block_ids: dict[str, list[list[int]]] = {} num_scheduled_tokens: dict[str, int] = {} token_budget = self.max_num_scheduled_tokens # Encoder-related. @@ -477,7 +477,8 @@ class Scheduler(SchedulerInterface): # Get the longest common prefix among all requests in the running queue. # This can be potentially used for cascade attention. - num_common_prefix_blocks = 0 + num_common_prefix_blocks = [0] * len( + self.kv_cache_config.kv_cache_groups) if self.running: any_request = self.running[0] num_common_prefix_blocks = ( @@ -564,7 +565,7 @@ class Scheduler(SchedulerInterface): request: Request, num_scheduled_tokens: int, num_scheduled_spec_tokens: int, - new_block_ids: list[int], + new_block_ids: list[list[int]], resumed_from_preemption: bool, ) -> CachedRequestData: # OPTIMIZATION: Cache the CachedRequestData objects to avoid creating @@ -939,7 +940,9 @@ class Scheduler(SchedulerInterface): """ if self.connector is None: return False, None - block_ids = self.kv_cache_manager.get_block_ids(request.request_id) + assert len(self.kv_cache_config.kv_cache_groups + ) == 1, "KV connector only supports one KV cache group now" + block_ids = self.kv_cache_manager.get_block_ids(request.request_id)[0] return self.connector.request_finished(request, block_ids) def _update_waiting_for_remote_kv(self, request: Request) -> bool: @@ -956,9 +959,10 @@ class Scheduler(SchedulerInterface): """ if request.request_id not in self.finished_recving_kv_req_ids: return False - + assert len(self.kv_cache_config.kv_cache_groups + ) == 1, "KV connector only supports one KV cache group now" # Now that the blocks are ready, actually cache them. 
-        block_ids = self.kv_cache_manager.get_block_ids(request.request_id)
+        block_ids = self.kv_cache_manager.get_block_ids(request.request_id)[0]
         num_computed_tokens = len(block_ids) * self.block_size
         if num_computed_tokens == request.num_tokens:
             num_computed_tokens -= 1
diff --git a/vllm/v1/kv_cache_interface.py b/vllm/v1/kv_cache_interface.py
index 4fc0844cd1f4d..2747fc7fabd1e 100644
--- a/vllm/v1/kv_cache_interface.py
+++ b/vllm/v1/kv_cache_interface.py
@@ -1,8 +1,11 @@
 # SPDX-License-Identifier: Apache-2.0
 
+import copy
 from dataclasses import dataclass
+from typing import Optional
 
 import torch
+from typing_extensions import Self
 
 from vllm.config import VllmConfig
 from vllm.logger import init_logger
@@ -53,6 +56,16 @@ class KVCacheSpec:
         """
         raise NotImplementedError
 
+    @classmethod
+    def merge(cls, specs: list[Self]) -> Self:
+        """
+        Merge a list of KVCacheSpec objects into a single KVCacheSpec object.
+        """
+        assert all(spec.type_id == specs[0].type_id for spec in specs[1:]), (
+            "All layers in the same KV cache group must share the same "
+            "type_id.")
+        return copy.deepcopy(specs[0])
+
 
 @dataclass
 class AttentionSpec(KVCacheSpec):
@@ -71,6 +84,16 @@ class AttentionSpec(KVCacheSpec):
 
 @dataclass
 class FullAttentionSpec(AttentionSpec):
+    sliding_window: Optional[int] = None
+    """
+    When the hybrid allocator is disabled and the model contains both full
+    attention layers and sliding window attention layers, sliding window
+    attention is regarded as full attention in the KV cache manager
+    (blocks are allocated for all tokens), while it is computed as sliding
+    window attention in the model runner.
+    In this case, we use FullAttentionSpec and record the sliding window size.
+    Defaults to None when sliding window attention is not used.
+    """
 
     @property
     def type_id(self) -> str:
@@ -80,6 +103,25 @@ class FullAttentionSpec(AttentionSpec):
         max_model_len = vllm_config.model_config.max_model_len
         return cdiv(max_model_len, self.block_size) * self.page_size_bytes
 
+    @classmethod
+    def merge(cls, specs: list[Self]) -> Self:
+        """
+        Merge a list of FullAttentionSpec objects into a single
+        FullAttentionSpec object.
+ """ + merged_spec = super().merge(specs) + sliding_window = set(spec.sliding_window for spec in specs + if spec.sliding_window is not None) + if len(sliding_window) == 0: + merged_spec.sliding_window = None + elif len(sliding_window) == 1: + merged_spec.sliding_window = sliding_window.pop() + else: + raise ValueError( + "All sliding window layers in the same KV cache group " + "must have the same window size.") + return merged_spec + @dataclass class SlidingWindowSpec(AttentionSpec): diff --git a/vllm/v1/worker/block_table.py b/vllm/v1/worker/block_table.py index 581d3d9bd11b5..0c3341691509f 100644 --- a/vllm/v1/worker/block_table.py +++ b/vllm/v1/worker/block_table.py @@ -4,6 +4,8 @@ import numpy as np import torch from vllm.logger import init_logger +from vllm.utils import cdiv +from vllm.v1.kv_cache_interface import KVCacheConfig logger = init_logger(__name__) @@ -96,3 +98,48 @@ class BlockTable: def get_numpy_array(self) -> np.ndarray: """Returns the numpy array of the block table.""" return self.block_table_np + + +class MultiGroupBlockTable: + """The BlockTables for each KV cache group.""" + + def __init__(self, max_num_reqs: int, max_model_len: int, + max_num_batched_tokens: int, pin_memory: bool, + device: torch.device, kv_cache_config: KVCacheConfig) -> None: + max_num_blocks_per_req = [ + cdiv(max_model_len, g.kv_cache_spec.block_size) + for g in kv_cache_config.kv_cache_groups + ] + self.block_tables = [ + BlockTable(max_num_reqs, max_num_blocks_per_req[i], + max_num_batched_tokens, pin_memory, device) + for i in range(len(kv_cache_config.kv_cache_groups)) + ] + + def append_row(self, block_ids: list[list[int]], row_idx: int) -> None: + for i, block_table in enumerate(self.block_tables): + block_table.append_row(block_ids[i], row_idx) + + def add_row(self, block_ids: list[list[int]], row_idx: int) -> None: + for i, block_table in enumerate(self.block_tables): + block_table.add_row(block_ids[i], row_idx) + + def move_row(self, src: int, tgt: int) -> None: + for block_table in self.block_tables: + block_table.move_row(src, tgt) + + def swap_row(self, src: int, tgt: int) -> None: + for block_table in self.block_tables: + block_table.swap_row(src, tgt) + + def commit(self, num_reqs: int) -> None: + for block_table in self.block_tables: + block_table.commit(num_reqs) + + def clear(self) -> None: + for block_table in self.block_tables: + block_table.clear() + + def __getitem__(self, idx: int) -> "BlockTable": + """Returns the BlockTable for the i-th KV cache group.""" + return self.block_tables[idx] diff --git a/vllm/v1/worker/gpu_input_batch.py b/vllm/v1/worker/gpu_input_batch.py index 871654fca3660..570de9bddd290 100644 --- a/vllm/v1/worker/gpu_input_batch.py +++ b/vllm/v1/worker/gpu_input_batch.py @@ -11,10 +11,11 @@ from vllm.lora.request import LoRARequest from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange from vllm.sampling_params import SamplingParams, SamplingType from vllm.utils import swap_dict_values +from vllm.v1.kv_cache_interface import KVCacheConfig from vllm.v1.outputs import LogprobsTensors from vllm.v1.sample.metadata import SamplingMetadata from vllm.v1.utils import copy_slice -from vllm.v1.worker.block_table import BlockTable +from vllm.v1.worker.block_table import MultiGroupBlockTable _SAMPLING_EPS = 1e-5 @@ -29,7 +30,7 @@ class CachedRequestState: sampling_params: SamplingParams generator: Optional[torch.Generator] - block_ids: list[int] + block_ids: list[list[int]] num_computed_tokens: int output_token_ids: list[int] @@ -58,15 +59,14 @@ 
class InputBatch: self, max_num_reqs: int, max_model_len: int, - max_num_blocks_per_req: int, max_num_batched_tokens: int, device: torch.device, pin_memory: bool, vocab_size: int, + kv_cache_config: KVCacheConfig, ): self.max_num_reqs = max_num_reqs self.max_model_len = max_model_len - self.max_num_blocks_per_req = max_num_blocks_per_req self.max_num_batched_tokens = max_num_batched_tokens self.device = device self.pin_memory = pin_memory @@ -99,12 +99,13 @@ class InputBatch: self.num_computed_tokens_cpu_tensor.numpy() # Block table. - self.block_table = BlockTable( + self.block_table = MultiGroupBlockTable( max_num_reqs=max_num_reqs, - max_num_blocks_per_req=max_num_blocks_per_req, + max_model_len=max_model_len, max_num_batched_tokens=max_num_batched_tokens, pin_memory=pin_memory, device=device, + kv_cache_config=kv_cache_config, ) # Sampling-related. diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 1b16f273a6de3..1b34a9fb06163 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -12,6 +12,8 @@ import torch.distributed import torch.nn as nn from vllm.attention import AttentionType, get_attn_backend +from vllm.attention.backends.abstract import (AttentionBackend, + AttentionMetadataBuilder) from vllm.attention.layer import Attention from vllm.attention.utils.fa_utils import get_flash_attn_version from vllm.config import (CompilationLevel, VllmConfig, @@ -31,8 +33,8 @@ from vllm.multimodal.utils import group_mm_inputs_by_modality from vllm.sampling_params import SamplingType from vllm.sequence import IntermediateTensors from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler, - GiB_bytes, LayerBlockType, LazyLoader, cdiv, - check_use_alibi, is_pin_memory_available) + GiB_bytes, LazyLoader, cdiv, check_use_alibi, + is_pin_memory_available) from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadata from vllm.v1.attention.backends.utils import CommonAttentionMetadata from vllm.v1.core.encoder_cache_manager import compute_encoder_budget @@ -49,6 +51,7 @@ from vllm.v1.spec_decode.metadata import SpecDecodeMetadata from vllm.v1.spec_decode.ngram_proposer import NgramProposer from vllm.v1.spec_decode.utils import is_spec_decode_supported from vllm.v1.utils import bind_kv_cache +from vllm.v1.worker.block_table import BlockTable from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin @@ -100,59 +103,17 @@ class GPUModelRunner(LoRAModelRunnerMixin): self.kv_cache_dtype = STR_DTYPE_TO_TORCH_DTYPE[ cache_config.cache_dtype] - # NOTE(woosuk): sliding_window is None for models with interleaved - # attention. Use interleaved_sliding_window instead. - self.sliding_window = model_config.get_sliding_window() - self.interleaved_sliding_window = getattr( - model_config.hf_text_config, "interleaved_sliding_window", None) - self.window_size = (self.sliding_window - or self.interleaved_sliding_window) - self.is_multimodal_model = model_config.is_multimodal_model - self.block_size = cache_config.block_size self.max_model_len = model_config.max_model_len - self.max_num_blocks_per_req = cdiv(self.max_model_len, self.block_size) self.max_num_tokens = scheduler_config.max_num_batched_tokens self.max_num_reqs = scheduler_config.max_num_seqs # Model-related. 
- self.num_attn_layers = model_config.get_num_layers_by_block_type( - parallel_config, LayerBlockType.attention) self.num_query_heads = model_config.get_num_attention_heads( parallel_config) - self.num_kv_heads = model_config.get_num_kv_heads(parallel_config) - self.head_size = model_config.get_head_size() self.hidden_size = model_config.get_hidden_size() self.attention_chunk_size = model_config.attention_chunk_size - self.attn_backend = get_attn_backend( - self.head_size, - self.dtype, - self.kv_cache_dtype, - self.block_size, - self.model_config.is_attention_free, - use_mla=self.model_config.use_mla, - ) - if self.attn_backend is None: - error_msg = ( - f"Error with get_att_backend: {self.head_size=}, " - f"{self.dtype=}, {self.kv_cache_dtype=}, {self.block_size=}, " - f"{self.model_config.is_attention_free=}, " - f"{self.model_config.use_mla=}") - logger.error(error_msg) - raise NotImplementedError( - "Non-Attention backend is not supported by V1 GPUModelRunner.") - - if self.vllm_config.compilation_config.full_cuda_graph: - attn_backend_name = self.attn_backend.__name__ - flash_attn_version = get_flash_attn_version() - if attn_backend_name != "FlashAttentionBackend" or \ - flash_attn_version != 3: - raise ValueError( - f"full_cuda_graph is only supported with " - f"FA3. Current attention backend is {attn_backend_name}, " - f"FlashAttention version is {flash_attn_version}.") - self.cascade_attn_enabled = not self.model_config.disable_cascade_attn # Multi-modal data support @@ -174,8 +135,10 @@ class GPUModelRunner(LoRAModelRunnerMixin): # self.model: nn.Module # Set after load_model # Initialize in initialize_kv_cache self.kv_caches: list[torch.Tensor] = [] + self.attn_metadata_builders: list[AttentionMetadataBuilder] = [] + self.attn_backends: list[type[AttentionBackend]] = [] # self.kv_cache_config: KVCacheConfig - # self.attn_metadata_builder: type[AttentionMetadataBuilder] + # self.input_batch: InputBatch # Persistent batch. # req_id -> (input_id -> encoder_output) self.encoder_cache: dict[str, dict[int, torch.Tensor]] = {} @@ -200,16 +163,6 @@ class GPUModelRunner(LoRAModelRunnerMixin): # Request states. self.requests: dict[str, CachedRequestState] = {} - # Persistent batch. - self.input_batch = InputBatch( - max_num_reqs=self.max_num_reqs, - max_model_len=self.max_model_len, - max_num_blocks_per_req=self.max_num_blocks_per_req, - max_num_batched_tokens=self.max_num_tokens, - device=self.device, - pin_memory=self.pin_memory, - vocab_size=model_config.get_vocab_size(), - ) self.use_cuda_graph = (self.vllm_config.compilation_config.level == CompilationLevel.PIECEWISE @@ -304,6 +257,31 @@ class GPUModelRunner(LoRAModelRunnerMixin): pin_memory=self.pin_memory) self.seq_lens_np = self.seq_lens_cpu.numpy() + def _may_reorder_batch(self, scheduler_output: "SchedulerOutput") -> bool: + """ + Update the order of requests in the batch based on the attention + backend's needs. For example, some attention backends (namely MLA) may + want to separate requests based on if the attention computation will be + compute-bound or memory-bound. + + Args: + scheduler_output: The scheduler output. + + Returns: + True if the batch was reordered, False otherwise. + """ + batch_reordered = self.attn_metadata_builders[0].reorder_batch( + self.input_batch, scheduler_output) + + # For models with multiple KV cache groups, the groups should agree on + # the same order of requests. 
We ensure this by only allowing the first + # group to reorder the batch and asserting that all other groups do not + # reorder the batch. + for i in range(1, len(self.kv_cache_config.kv_cache_groups)): + assert not self.attn_metadata_builders[i].reorder_batch( + self.input_batch, scheduler_output) + return batch_reordered + def _update_states(self, scheduler_output: "SchedulerOutput") -> None: """Update the cached states and the persistent batch with the scheduler output. @@ -440,7 +418,8 @@ class GPUModelRunner(LoRAModelRunnerMixin): # Update the block IDs. if not req_data.resumed_from_preemption: # Append the new blocks to the existing block IDs. - req_state.block_ids.extend(req_data.new_block_ids) + for i in range(len(self.kv_cache_config.kv_cache_groups)): + req_state.block_ids[i].extend(req_data.new_block_ids[i]) else: # The request is resumed from preemption. # Replace the existing block IDs with the new ones. @@ -498,11 +477,7 @@ class GPUModelRunner(LoRAModelRunnerMixin): if removed_req_indices: self.input_batch.condense(removed_req_indices) - # Some attention backends (namely MLA) may want to separate requests - # based on if the attention computation will be compute-bound or - # memory-bound. This gives them a hook to do that. - batch_reordered = self.attn_metadata_builder.reorder_batch( - self.input_batch, scheduler_output) + batch_reordered = self._may_reorder_batch(scheduler_output) if batch_changed or batch_reordered: self.input_batch.refresh_sampling_metadata() @@ -570,21 +545,29 @@ class GPUModelRunner(LoRAModelRunnerMixin): torch.from_numpy(token_indices), out=self.input_ids_cpu[:total_num_scheduled_tokens]) - # Calculate the slot mapping. - # E.g., [0, 1, 0, 1, 2, 3, 4, 0, 1, 2] - # -> [0, 0, K, K, K + 1, K + 1, K + 2, 2 * K, 2 * K, 2 * K + 1] - # where K is the max_num_blocks_per_req and the block size is 2. - # NOTE(woosuk): We can't simply use `token_indices // block_size` here - # because M (max_model_len) is not necessarily divisible by block_size. - block_table_indices = (req_indices * self.max_num_blocks_per_req + - positions_np // self.block_size) - block_table_cpu = self.input_batch.block_table.get_cpu_tensor() - block_numbers = block_table_cpu.flatten()[block_table_indices].numpy() - block_offsets = positions_np % self.block_size - np.add(block_numbers * self.block_size, - block_offsets, - out=self.input_batch.block_table. - slot_mapping_np[:total_num_scheduled_tokens]) + # Calculate the slot mapping for each KV cache group. + for kv_cache_group_id, kv_cache_group_spec in enumerate( + self.kv_cache_config.kv_cache_groups): + block_size = kv_cache_group_spec.kv_cache_spec.block_size + block_table: BlockTable = self.input_batch.block_table[ + kv_cache_group_id] + # E.g., [0, 1, 0, 1, 2, 3, 4, 0, 1, 2] + # -> [0, 0, K, K, K + 1, K + 1, K + 2, 2 * K, 2 * K, 2 * K + 1] + # where K is the max_num_blocks_per_req and the block size is 2. + # NOTE(woosuk): We can't simply use `token_indices // block_size` + # here because M (max_model_len) is not necessarily divisible by + # block_size. + block_table_indices = ( + req_indices * block_table.max_num_blocks_per_req + + positions_np // block_size) + block_table_cpu = block_table.get_cpu_tensor() + block_numbers = block_table_cpu.flatten( + )[block_table_indices].numpy() + block_offsets = positions_np % block_size + np.add( + block_numbers * block_size, + block_offsets, + out=block_table.slot_mapping_np[:total_num_scheduled_tokens]) # Prepare the attention metadata. 
self.query_start_loc_np[0] = 0 @@ -626,10 +609,6 @@ class GPUModelRunner(LoRAModelRunnerMixin): attn_metadata: dict[str, FlashAttentionMetadata] = {} # Prepare the attention metadata for each KV cache group and make layers # in the same group share the same metadata. - # NOTE(Chen): there is exactly one KV cache group that contains all - # attetnion layers in the model for now, so the current logic for - # getting attn_metadata is not related to kv_cache_group information. - # Will extend this part to support multiple KV cache groups later. for kv_cache_group_id, kv_cache_group_spec in enumerate( self.kv_cache_config.kv_cache_groups): @@ -638,15 +617,19 @@ class GPUModelRunner(LoRAModelRunnerMixin): if self.cascade_attn_enabled: common_prefix_len = self._compute_cascade_attn_prefix_len( num_scheduled_tokens, - scheduler_output.num_common_prefix_blocks, + scheduler_output. + num_common_prefix_blocks[kv_cache_group_id], + kv_cache_group_spec.kv_cache_spec, + self.attn_metadata_builders[kv_cache_group_id], ) - attn_metadata_i = self.attn_metadata_builder.build( - num_reqs=num_reqs, - num_actual_tokens=total_num_scheduled_tokens, - max_query_len=max_num_scheduled_tokens, - common_prefix_len=common_prefix_len, - common_attn_metadata=common_attn_metadata) + attn_metadata_i = ( + self.attn_metadata_builders[kv_cache_group_id].build( + num_reqs=num_reqs, + num_actual_tokens=total_num_scheduled_tokens, + max_query_len=max_num_scheduled_tokens, + common_prefix_len=common_prefix_len, + common_attn_metadata=common_attn_metadata)) for layer_name in kv_cache_group_spec.layer_names: attn_metadata[layer_name] = attn_metadata_i @@ -684,6 +667,8 @@ class GPUModelRunner(LoRAModelRunnerMixin): self, num_scheduled_tokens: np.ndarray, num_common_prefix_blocks: int, + kv_cache_spec: KVCacheSpec, + attn_metadata_builder: AttentionMetadataBuilder, ) -> int: """Compute the length of the common prefix for cascade attention. @@ -702,7 +687,7 @@ class GPUModelRunner(LoRAModelRunnerMixin): Returns: int: Length of common prefix in tokens. """ - common_prefix_len = num_common_prefix_blocks * self.block_size + common_prefix_len = num_common_prefix_blocks * kv_cache_spec.block_size if common_prefix_len == 0: # Common case. return 0 @@ -751,15 +736,19 @@ class GPUModelRunner(LoRAModelRunnerMixin): common_prefix_len, self.input_batch.num_computed_tokens_cpu[:num_reqs].min()) # common_prefix_len should be a multiple of the block size. 
- common_prefix_len = (common_prefix_len // self.block_size * - self.block_size) - use_cascade = self.attn_metadata_builder.use_cascade_attention( + common_prefix_len = (common_prefix_len // kv_cache_spec.block_size * + kv_cache_spec.block_size) + use_sliding_window = (isinstance(kv_cache_spec, SlidingWindowSpec) or + (isinstance(kv_cache_spec, FullAttentionSpec) + and kv_cache_spec.sliding_window is not None)) + assert isinstance(kv_cache_spec, AttentionSpec) + use_cascade = attn_metadata_builder.use_cascade_attention( common_prefix_len=common_prefix_len, query_lens=num_scheduled_tokens, num_query_heads=self.num_query_heads, - num_kv_heads=self.num_kv_heads, + num_kv_heads=kv_cache_spec.num_kv_heads, use_alibi=self.use_alibi, - use_sliding_window=self.window_size is not None, + use_sliding_window=use_sliding_window, num_sms=self.num_sms, ) return common_prefix_len if use_cascade else 0 @@ -1577,7 +1566,7 @@ class GPUModelRunner(LoRAModelRunnerMixin): dtype=np.int32) if skip_attn: - attn_metadata = None + attn_metadata: Optional[dict[str, FlashAttentionMetadata]] = None else: query_start_loc = self.query_start_loc[:num_reqs + 1] seq_lens = self.seq_lens[:num_reqs] @@ -1585,13 +1574,19 @@ class GPUModelRunner(LoRAModelRunnerMixin): common_attn_metadata = CommonAttentionMetadata( query_start_loc=query_start_loc, seq_lens=seq_lens) - attn_metadata = self.attn_metadata_builder.build( - num_reqs=num_tokens, - num_actual_tokens=num_tokens, - max_query_len=num_tokens, - common_prefix_len=0, - common_attn_metadata=common_attn_metadata, - ) + attn_metadata = {} + for kv_cache_group_id, kv_cache_group_spec in enumerate( + self.kv_cache_config.kv_cache_groups): + attn_metadata_i = ( + self.attn_metadata_builders[kv_cache_group_id].build( + num_reqs=num_tokens, + num_actual_tokens=num_tokens, + max_query_len=num_tokens, + common_prefix_len=0, + common_attn_metadata=common_attn_metadata, + )) + for layer_name in kv_cache_group_spec.layer_names: + attn_metadata[layer_name] = attn_metadata_i with self.maybe_dummy_run_with_lora(self.lora_config, num_scheduled_tokens): @@ -1822,6 +1817,56 @@ class GPUModelRunner(LoRAModelRunnerMixin): logger.info("Graph capturing finished in %.0f secs, took %.2f GiB", elapsed_time, cuda_graph_size / (1 << 30)) + def initialize_attn_backend(self, kv_cache_config: KVCacheConfig) -> None: + """ + Initialize the attention backends and attention metadata builders. 
+ """ + assert len(self.attn_backends) == 0 and len( + self.attn_metadata_builders + ) == 0, "Attention backends are already initialized" + for i, kv_cache_group_spec in enumerate( + kv_cache_config.kv_cache_groups): + kv_cache_spec = kv_cache_group_spec.kv_cache_spec + if not isinstance(kv_cache_spec, AttentionSpec): + raise NotImplementedError( + "Only AttentionSpec is supported for now.") + attn_backend_i = get_attn_backend( + kv_cache_spec.head_size, + self.dtype, + kv_cache_spec.dtype, + kv_cache_spec.block_size, + self.model_config.is_attention_free, + use_mla=kv_cache_spec.use_mla, + ) + if attn_backend_i is None: + error_msg = ( + f"Error with get_attn_backend: {kv_cache_spec.head_size=}, " + f"{self.dtype=}, {kv_cache_spec.dtype=}, " + f"{kv_cache_spec.block_size=}, " + f"{self.model_config.is_attention_free=}, " + f"{kv_cache_spec.use_mla=}") + logger.error(error_msg) + raise NotImplementedError( + "Non-Attention backend is not supported by V1 " + "GPUModelRunner.") + + if self.vllm_config.compilation_config.full_cuda_graph: + attn_backend_name = attn_backend_i.__name__ + flash_attn_version = get_flash_attn_version() + if attn_backend_name != "FlashAttentionBackend" or \ + flash_attn_version != 3: + raise ValueError( + f"full_cuda_graph is only supported with " + f"FA3. Current attention backend is " + f"{attn_backend_name}, FlashAttention version is " + f"{flash_attn_version}.") + + block_table_i = self.input_batch.block_table[i] + attn_metadata_builder_i = attn_backend_i.get_builder_cls()( + weakref.proxy(self), kv_cache_spec, block_table_i) + self.attn_backends.append(attn_backend_i) + self.attn_metadata_builders.append(attn_metadata_builder_i) + def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None: """ Initialize KV cache based on `kv_cache_config`. @@ -1829,15 +1874,21 @@ class GPUModelRunner(LoRAModelRunnerMixin): kv_cache_config: Configuration for the KV cache, including the KV cache size of each layer """ - if len(kv_cache_config.kv_cache_groups) > 1: - raise NotImplementedError( - "Hybrid models with more than one KV cache type are not " - "supported yet.") self.kv_cache_config = kv_cache_config + self.input_batch = InputBatch( + max_num_reqs=self.max_num_reqs, + max_model_len=self.max_model_len, + max_num_batched_tokens=self.max_num_tokens, + device=self.device, + pin_memory=self.pin_memory, + vocab_size=self.model_config.get_vocab_size(), + kv_cache_config=kv_cache_config, + ) + self.initialize_attn_backend(kv_cache_config) kv_caches: dict[str, torch.Tensor] = {} - for kv_cache_group in kv_cache_config.kv_cache_groups: + for i, kv_cache_group in enumerate(kv_cache_config.kv_cache_groups): kv_cache_spec = kv_cache_group.kv_cache_spec for layer_name in kv_cache_group.layer_names: tensor_config = kv_cache_config.tensors[layer_name] @@ -1852,7 +1903,7 @@ class GPUModelRunner(LoRAModelRunnerMixin): # the min of all `num_blocks`. Verify it here. 
assert num_blocks >= kv_cache_config.num_blocks if isinstance(kv_cache_spec, AttentionSpec): - kv_cache_shape = self.attn_backend.get_kv_cache_shape( + kv_cache_shape = self.attn_backends[i].get_kv_cache_shape( num_blocks, kv_cache_spec.block_size, kv_cache_spec.num_kv_heads, kv_cache_spec.head_size) dtype = kv_cache_spec.dtype @@ -1872,11 +1923,6 @@ class GPUModelRunner(LoRAModelRunnerMixin): if has_kv_transfer_group(): get_kv_transfer_group().register_kv_caches(kv_caches) - self.attn_metadata_builder = self.attn_backend.get_builder_cls()( - weakref.proxy(self), - kv_cache_config.kv_cache_groups[0].kv_cache_spec, - self.input_batch.block_table) - def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]: """ Generates the KVCacheSpec by parsing the kv cache format from each diff --git a/vllm/v1/worker/tpu_model_runner.py b/vllm/v1/worker/tpu_model_runner.py index b4daf5a346788..2da99696445ee 100644 --- a/vllm/v1/worker/tpu_model_runner.py +++ b/vllm/v1/worker/tpu_model_runner.py @@ -171,19 +171,10 @@ class TPUModelRunner(LoRAModelRunnerMixin): self.kv_caches: list[torch.Tensor] = [] # req_id -> (input_id -> encoder_output) self.encoder_cache: dict[str, dict[int, torch.Tensor]] = {} + # self.input_batch: InputBatch # Persistent batch. # Request states. self.requests: dict[str, CachedRequestState] = {} - # Persistent batch. - self.input_batch = InputBatch( - max_num_reqs=self.max_num_reqs, - max_model_len=self.max_model_len, - max_num_blocks_per_req=self.max_num_blocks_per_req, - max_num_batched_tokens=self.max_num_tokens, - device=self.device, - pin_memory=self.pin_memory, - vocab_size=self.vocab_size, - ) # Cached torch/numpy tensor # The pytorch tensor and numpy array share the same buffer. @@ -199,7 +190,7 @@ class TPUModelRunner(LoRAModelRunnerMixin): self.block_table_cpu = torch.zeros( (self.max_num_reqs, self.max_num_blocks_per_req), - dtype=self.input_batch.block_table.get_cpu_tensor().dtype, + dtype=torch.int32, device="cpu") self.query_start_loc_cpu = torch.zeros(self.max_num_tokens + 1, @@ -524,12 +515,12 @@ class TPUModelRunner(LoRAModelRunnerMixin): # NOTE(woosuk): We use torch.index_select instead of np.take here # because torch.index_select is much faster than np.take for large # tensors. - block_table_cpu = self.input_batch.block_table.get_cpu_tensor() + block_table_cpu = self.input_batch.block_table[0].get_cpu_tensor() block_numbers = block_table_cpu.flatten()[block_table_indices].numpy() block_offsets = positions_np % self.block_size np.add(block_numbers * self.block_size, block_offsets, - out=self.input_batch.block_table. + out=self.input_batch.block_table[0]. slot_mapping_np[:total_num_scheduled_tokens]) # Prepare the attention metadata. @@ -554,15 +545,15 @@ class TPUModelRunner(LoRAModelRunnerMixin): self.position_ids = self.positions_cpu[: padded_total_num_scheduled_tokens].to( self.device) - self.input_batch.block_table.slot_mapping_cpu[ + self.input_batch.block_table[0].slot_mapping_cpu[ total_num_scheduled_tokens:] = _PAD_SLOT_ID slot_mapping = ( - self.input_batch.block_table. + self.input_batch.block_table[0]. 
slot_mapping_cpu[:padded_total_num_scheduled_tokens].to( self.device)) block_tables = self.block_table_cpu[:self.max_num_reqs] block_tables[:num_reqs, :self.max_num_blocks_per_req] = ( - self.input_batch.block_table.get_cpu_tensor()[:num_reqs]) + self.input_batch.block_table[0].get_cpu_tensor()[:num_reqs]) block_tables = block_tables.to(self.device) query_start_loc = self.query_start_loc_cpu[:self.max_num_reqs + 1].to( self.device) @@ -1263,6 +1254,18 @@ class TPUModelRunner(LoRAModelRunnerMixin): "Hybrid models with more than one KV cache type are not " "supported yet.") + self.input_batch = InputBatch( + max_num_reqs=self.max_num_reqs, + max_model_len=self.max_model_len, + max_num_batched_tokens=self.max_num_tokens, + device=self.device, + pin_memory=self.pin_memory, + vocab_size=self.model_config.get_vocab_size(), + kv_cache_config=kv_cache_config, + ) + assert self.block_table_cpu.dtype == self.input_batch.block_table[ + 0].get_cpu_tensor().dtype + kv_caches: dict[str, torch.Tensor] = {} for kv_cache_group in kv_cache_config.kv_cache_groups: