diff --git a/tests/v1/worker/test_gpu_input_batch.py b/tests/v1/worker/test_gpu_input_batch.py index 915ec2914a82..7b1359c8576f 100644 --- a/tests/v1/worker/test_gpu_input_batch.py +++ b/tests/v1/worker/test_gpu_input_batch.py @@ -221,6 +221,7 @@ def test_sampling_metadata_in_input_batch(device: str, batch_size: int): max_num_reqs=batch_size, max_model_len=1024, max_num_blocks_per_req=10, + max_num_batched_tokens=1024, device=torch.device(device), pin_memory=is_pin_memory_available(), vocab_size=1024, @@ -310,6 +311,7 @@ def test_swap_states_in_input_batch(device: str, batch_size: int, max_num_reqs=batch_size, max_model_len=1024, max_num_blocks_per_req=10, + max_num_batched_tokens=1024, device=torch.device(device), pin_memory=is_pin_memory_available(), vocab_size=1024, @@ -318,6 +320,7 @@ def test_swap_states_in_input_batch(device: str, batch_size: int, max_num_reqs=batch_size, max_model_len=1024, max_num_blocks_per_req=10, + max_num_batched_tokens=1024, device=torch.device(device), pin_memory=is_pin_memory_available(), vocab_size=1024, diff --git a/tests/v1/worker/test_gpu_model_runner.py b/tests/v1/worker/test_gpu_model_runner.py index 68e34cfacc58..725747294fd8 100644 --- a/tests/v1/worker/test_gpu_model_runner.py +++ b/tests/v1/worker/test_gpu_model_runner.py @@ -1,14 +1,31 @@ # SPDX-License-Identifier: Apache-2.0 +import weakref + import pytest +import torch from vllm.config import CacheConfig, ModelConfig, SchedulerConfig, VllmConfig from vllm.sampling_params import SamplingParams from vllm.v1.core.sched.output import (CachedRequestData, NewRequestData, SchedulerOutput) +from vllm.v1.kv_cache_interface import FullAttentionSpec from vllm.v1.sample.metadata import SamplingMetadata from vllm.v1.worker.gpu_model_runner import GPUModelRunner +def initialize_kv_cache(runner: GPUModelRunner): + """ + Only perform necessary steps in GPUModelRunner.initialize_kv_cache() + """ + kv_cache_spec = FullAttentionSpec(block_size=16, + num_kv_heads=1, + head_size=64, + dtype=torch.float16, + use_mla=False) + runner.attn_metadata_builder = runner.attn_backend.get_builder_cls()( + weakref.proxy(runner), kv_cache_spec, runner.input_batch.block_table) + + @pytest.fixture def model_runner(): scheduler_config = SchedulerConfig( @@ -38,7 +55,9 @@ def model_runner(): ) device = "cuda" - return GPUModelRunner(vllm_config, device) + runner = GPUModelRunner(vllm_config, device) + initialize_kv_cache(runner) + return runner def _schedule_new_request(*req_ids: str) -> SchedulerOutput: diff --git a/vllm/v1/attention/backends/flash_attn.py b/vllm/v1/attention/backends/flash_attn.py index 605dff3749fb..9ed3dec7f269 100755 --- a/vllm/v1/attention/backends/flash_attn.py +++ b/vllm/v1/attention/backends/flash_attn.py @@ -19,6 +19,8 @@ from vllm.logger import init_logger from vllm.platforms import current_platform from vllm.utils import cdiv from vllm.v1.attention.backends.utils import CommonAttentionMetadata +from vllm.v1.kv_cache_interface import AttentionSpec +from vllm.v1.worker.block_table import BlockTable if TYPE_CHECKING: from vllm.v1.core.sched.output import SchedulerOutput @@ -167,7 +169,7 @@ def make_local_attention_virtual_batches( query_start_loc_np: np.ndarray, seq_lens_np: np.ndarray, block_table: torch.Tensor, - page_size: int = 0, + block_size: int = 0, ) -> tuple[np.ndarray, np.ndarray, np.ndarray, torch.Tensor]: q_seqlens = query_start_loc_np[1:] - query_start_loc_np[:-1] actual_batch_size = seq_lens_np.shape[0] @@ -238,14 +240,14 @@ def make_local_attention_virtual_batches( # For the example the 
local attention blocks start at: # _b0_ _____b1_____ _b2_ # k_seqstarts_absolute = [0, 4, 4, 8, 12, 16, 4, 8] - block_starts = k_seqstarts_absolute // page_size - assert attn_chunk_size % page_size == 0, \ + block_starts = k_seqstarts_absolute // block_size + assert attn_chunk_size % block_size == 0, \ f"attn_chunk_size {attn_chunk_size} is not " \ - f"divisible by page_size {page_size}" - pages_per_local_batch = attn_chunk_size // page_size + f"divisible by block_size {block_size}" + pages_per_local_batch = attn_chunk_size // block_size # Create a block_table for the local attention blocks - # For out example if we have a block-table like (assuming page_size=2): + # For out example if we have a block-table like (assuming block_size=2): # block_table = [ # [ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9], < batch 0 # [10, 11, 12, 13, 14, 15, 16, 17, 18, 19], < batch 1 @@ -289,7 +291,8 @@ def _get_sliding_window_configs( class FlashAttentionMetadataBuilder: - def __init__(self, runner: "GPUModelRunner"): + def __init__(self, runner: "GPUModelRunner", kv_cache_spec: AttentionSpec, + block_table: BlockTable): model_config = runner.model_config compilation_config = runner.vllm_config.compilation_config @@ -299,7 +302,9 @@ class FlashAttentionMetadataBuilder: self.num_heads_kv = model_config.get_num_kv_heads( runner.parallel_config) self.headdim = model_config.get_head_size() - self.page_size = self.runner.block_size + self.block_size = kv_cache_spec.block_size + self.kv_cache_spec = kv_cache_spec + self.block_table = block_table if get_flash_attn_version() == 3: self.aot_schedule = not compilation_config.full_cuda_graph @@ -323,9 +328,17 @@ class FlashAttentionMetadataBuilder: max_seq_len = self.runner.seq_lens_np[:num_reqs].max() query_start_loc = common_attn_metadata.query_start_loc seq_lens = common_attn_metadata.seq_lens - block_table = ( - self.runner.input_batch.block_table.get_device_tensor()[:num_reqs]) - slot_mapping = self.runner.slot_mapping[:num_actual_tokens] + block_table = self.block_table + block_table_tensor = block_table.get_device_tensor()[:num_reqs] + + block_table.slot_mapping[:num_actual_tokens].copy_( + block_table.slot_mapping_cpu[:num_actual_tokens], + non_blocking=True) + # Fill unused with -1. Needed for reshape_and_cache in full cuda graph + # mode. 
+ block_table.slot_mapping[num_actual_tokens:].fill_(-1) + + slot_mapping = block_table.slot_mapping[:num_actual_tokens] if self.aot_sliding_window is None: self.aot_sliding_window = (-1, -1) @@ -354,7 +367,7 @@ class FlashAttentionMetadataBuilder: num_heads_q=self.num_heads_q, num_heads_kv=self.num_heads_kv, headdim=self.headdim, - page_size=self.page_size, + page_size=self.block_size, cu_seqlens_q=cu_query_lens, causal=causal, window_size=self.aot_sliding_window, @@ -365,12 +378,12 @@ class FlashAttentionMetadataBuilder: local_attn_metadata = None if self.runner.attention_chunk_size is not None: seqlens_q_local_np, virt_q_cu_seqlens_np, virt_k_seqlens_np, \ - virt_block_table = make_local_attention_virtual_batches( + virt_block_table_tensor = make_local_attention_virtual_batches( self.runner.attention_chunk_size, self.runner.query_start_loc_np[:num_reqs + 1], self.runner.seq_lens_np[:num_reqs], - block_table, - self.runner.block_size, + block_table_tensor, + self.block_size, ) local_query_start_loc = torch.from_numpy(virt_q_cu_seqlens_np).to( self.runner.device, non_blocking=True) @@ -389,7 +402,7 @@ class FlashAttentionMetadataBuilder: local_attn_metadata = FlashAttentionMetadata.LocalAttentionMetadata( local_query_start_loc=local_query_start_loc, local_seqused_k=local_seqused_k, - local_block_table=virt_block_table, + local_block_table=virt_block_table_tensor, local_max_query_len=local_max_query_len, local_max_seq_len=local_max_seq_len, local_scheduler_metadata=local_scheduler_metadata, @@ -440,7 +453,7 @@ class FlashAttentionMetadataBuilder: query_start_loc=query_start_loc, max_seq_len=max_seq_len, seq_lens=seq_lens, - block_table=block_table, + block_table=block_table_tensor, slot_mapping=slot_mapping, use_cascade=use_cascade, common_prefix_len=common_prefix_len, diff --git a/vllm/v1/attention/backends/flashinfer.py b/vllm/v1/attention/backends/flashinfer.py index 0852e15f9c19..dcc33cffb1d7 100755 --- a/vllm/v1/attention/backends/flashinfer.py +++ b/vllm/v1/attention/backends/flashinfer.py @@ -19,6 +19,8 @@ from vllm.config import (VllmConfig, get_current_vllm_config, from vllm.logger import init_logger from vllm.v1.attention.backends.flash_attn import use_cascade_attention from vllm.v1.attention.backends.utils import CommonAttentionMetadata +from vllm.v1.kv_cache_interface import AttentionSpec +from vllm.v1.worker.block_table import BlockTable if TYPE_CHECKING: from vllm.v1.core.sched.output import SchedulerOutput @@ -202,7 +204,8 @@ class FlashInferMetadata: class FlashInferMetadataBuilder: - def __init__(self, runner: GPUModelRunner): + def __init__(self, runner: GPUModelRunner, kv_cache_spec: AttentionSpec, + block_table: BlockTable): self.runner = runner self._workspace_buffer = None self._prefill_wrapper = None # Wrapper for prefill/append @@ -213,6 +216,8 @@ class FlashInferMetadataBuilder: self.global_hyperparameters: Optional[PerLayerParameters] = None self.vllm_config = get_current_vllm_config() + self.kv_cache_spec = kv_cache_spec + self.block_table = block_table def reorder_batch(self, input_batch: InputBatch, scheduler_output: SchedulerOutput) -> bool: @@ -400,13 +405,12 @@ class FlashInferMetadataBuilder: assert self._num_decodes + self._num_prefills == num_reqs assert (self._num_decode_tokens + self._num_prefill_tokens == num_actual_tokens) - page_size = self.runner.block_size + page_size = self.kv_cache_spec.block_size device = self.runner.device qo_indptr = common_attn_metadata.query_start_loc seq_lens = common_attn_metadata.seq_lens - block_table = ( - 
self.runner.input_batch.block_table.get_device_tensor()[:num_reqs]) - slot_mapping = self.runner.slot_mapping_cpu[:num_actual_tokens].to( + block_table_tensor = self.block_table.get_device_tensor()[:num_reqs] + slot_mapping = self.block_table.slot_mapping_cpu[:num_actual_tokens].to( self.runner.device, non_blocking=True).long() block_table_bounds = (seq_lens + page_size - 1) // page_size @@ -422,12 +426,13 @@ class FlashInferMetadataBuilder: shared_kv_page_indptr = torch.tensor([0, num_common_kv_blocks], dtype=torch.int32, device=device) - shared_kv_page_indices = block_table[0, :num_common_kv_blocks] + shared_kv_page_indices = block_table_tensor[ + 0, :num_common_kv_blocks] shared_kv_last_page_len = torch.tensor([page_size], dtype=torch.int32, device=device) # Remove the blocks of the shared prefix from all requests. - block_table = block_table[:, num_common_kv_blocks:] + block_table_tensor = block_table_tensor[:, num_common_kv_blocks:] block_table_bounds -= num_common_kv_blocks else: shared_qo_indptr = None @@ -435,11 +440,11 @@ class FlashInferMetadataBuilder: shared_kv_page_indices = None shared_kv_last_page_len = None - mask = (torch.arange(block_table.size(1), - dtype=block_table.dtype, - device=block_table.device).unsqueeze(0) + mask = (torch.arange(block_table_tensor.size(1), + dtype=block_table_tensor.dtype, + device=block_table_tensor.device).unsqueeze(0) < block_table_bounds.unsqueeze(1)) - paged_kv_indices = block_table[mask] + paged_kv_indices = block_table_tensor[mask] paged_kv_indptr = torch.cat([ torch.zeros(1, @@ -459,10 +464,10 @@ class FlashInferMetadataBuilder: paged_kv_indices=paged_kv_indices, paged_kv_last_page_len=paged_kv_last_page_len, num_qo_heads=self.runner.num_query_heads, - num_kv_heads=self.runner.num_kv_heads, - head_dim=self.runner.head_size, + num_kv_heads=self.kv_cache_spec.num_kv_heads, + head_dim=self.kv_cache_spec.head_size, page_size=page_size, - data_type=self.runner.kv_cache_dtype, + data_type=self.kv_cache_spec.dtype, q_data_type=self.runner.dtype, slot_mapping=slot_mapping, num_decodes=self._num_decodes, @@ -481,7 +486,7 @@ class FlashInferMetadataBuilder: return attn_metadata def use_cascade_attention(self, *args, **kwargs) -> bool: - if self.runner.kv_cache_dtype != self.runner.model_config.dtype: + if self.kv_cache_spec.dtype != self.runner.model_config.dtype: # TODO: The cascade wrapper currently does not support setting # kv cache dtype to something different from query dtype. 
return False diff --git a/vllm/v1/attention/backends/mla/common.py b/vllm/v1/attention/backends/mla/common.py index 0c740fbcc6b7..69fc1ac69ab6 100644 --- a/vllm/v1/attention/backends/mla/common.py +++ b/vllm/v1/attention/backends/mla/common.py @@ -207,6 +207,8 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear, from vllm.platforms import current_platform from vllm.utils import cdiv, round_down from vllm.v1.attention.backends.utils import CommonAttentionMetadata +from vllm.v1.kv_cache_interface import AttentionSpec +from vllm.v1.worker.block_table import BlockTable try: from vllm.vllm_flash_attn import flash_attn_varlen_func @@ -334,6 +336,8 @@ class MLACommonMetadataBuilder(Generic[M]): def __init__(self, runner: "GPUModelRunner", + kv_cache_spec: AttentionSpec, + block_table: BlockTable, metadata_cls: Optional[type[M]] = None): self.metadata_cls = metadata_cls \ if metadata_cls is not None else MLACommonMetadata @@ -346,10 +350,11 @@ class MLACommonMetadataBuilder(Generic[M]): runner.parallel_config) self.mla_dims = get_mla_dims(model_config) self.aot_schedule = is_vllm_fa and (get_flash_attn_version() == 3) + self.kv_cache_spec = kv_cache_spec # Dont try to access the runner on AMD if self.aot_schedule: - self.page_size = self.runner.block_size + self.page_size = self.kv_cache_spec.block_size if self.chunked_prefill_enabled: self.chunked_prefill_workspace_size = min( @@ -375,6 +380,7 @@ class MLACommonMetadataBuilder(Generic[M]): dtype=model_config.dtype, device=runner.device, ) + self.block_table = block_table def reorder_batch(self, input_batch: "InputBatch", scheduler_output: "SchedulerOutput") -> bool: @@ -436,9 +442,10 @@ class MLACommonMetadataBuilder(Generic[M]): return modified_batch - def _build_decode(self, block_table: torch.Tensor, seq_lens: torch.Tensor): + def _build_decode(self, block_table_tensor: torch.Tensor, + seq_lens: torch.Tensor): return MLACommonDecodeMetadata( - block_table=block_table, + block_table=block_table_tensor, seq_lens=seq_lens, ) @@ -451,9 +458,9 @@ class MLACommonMetadataBuilder(Generic[M]): # function. We should avoid GPU -> CPU sync as much as possible because # it blocks on all previous kernels. 
device = self.runner.device - block_table = ( - self.runner.input_batch.block_table.get_device_tensor()[:num_reqs]) - slot_mapping = self.runner.slot_mapping_cpu[:num_actual_tokens].to( + block_table = self.block_table + block_table_tensor = block_table.get_device_tensor()[:num_reqs] + slot_mapping = block_table.slot_mapping_cpu[:num_actual_tokens].to( device, non_blocking=True).long() query_start_loc = common_attn_metadata.query_start_loc @@ -530,7 +537,7 @@ class MLACommonMetadataBuilder(Generic[M]): self.chunked_prefill_workspace_size prefill_metadata = MLACommonPrefillMetadata( - block_table=block_table[reqs_start:, ...], + block_table=block_table_tensor[reqs_start:, ...], query_start_loc=prefill_query_start_loc, max_query_len=max_query_len, chunked_context=chunked_context_metadata, @@ -539,7 +546,7 @@ class MLACommonMetadataBuilder(Generic[M]): decode_metadata = None if self._num_decodes > 0: decode_metadata = self._build_decode( - block_table=block_table[:self._num_decodes, ...], + block_table_tensor=block_table_tensor[:self._num_decodes, ...], seq_lens=seq_lens[:self._num_decodes], ) diff --git a/vllm/v1/attention/backends/mla/flashmla.py b/vllm/v1/attention/backends/mla/flashmla.py index 2f35f9b0a54f..e6594c6b6fa8 100644 --- a/vllm/v1/attention/backends/mla/flashmla.py +++ b/vllm/v1/attention/backends/mla/flashmla.py @@ -16,6 +16,8 @@ from vllm.v1.attention.backends.mla.common import (MLACommonBackend, MLACommonImpl, MLACommonMetadata, MLACommonMetadataBuilder) +from vllm.v1.kv_cache_interface import AttentionSpec +from vllm.v1.worker.block_table import BlockTable logger = init_logger(__name__) @@ -52,13 +54,14 @@ class FlashMLAMetadata(MLACommonMetadata[FlashMLADecodeMetadata]): class FlashMLAMetadataBuilder(MLACommonMetadataBuilder[FlashMLAMetadata]): - def __init__(self, runner): - super().__init__(runner) + def __init__(self, runner, kv_cache_spec: AttentionSpec, + block_table: BlockTable): + super().__init__(runner, kv_cache_spec, block_table) self.num_q_heads = self.runner.model_config.get_num_attention_heads( self.runner.parallel_config) - def _build_decode(self, block_table: torch.Tensor, + def _build_decode(self, block_table_tensor: torch.Tensor, seq_lens: torch.Tensor) -> FlashMLADecodeMetadata: tile_scheduler_metadata, num_splits = \ get_mla_metadata( @@ -68,7 +71,7 @@ class FlashMLAMetadataBuilder(MLACommonMetadataBuilder[FlashMLAMetadata]): ) return FlashMLADecodeMetadata( - block_table=block_table, + block_table=block_table_tensor, seq_lens=seq_lens, tile_scheduler_metadata=tile_scheduler_metadata, num_splits=num_splits, diff --git a/vllm/v1/attention/backends/mla/rocm_aiter_mla.py b/vllm/v1/attention/backends/mla/rocm_aiter_mla.py index 37b72c08d52b..f46010d757af 100644 --- a/vllm/v1/attention/backends/mla/rocm_aiter_mla.py +++ b/vllm/v1/attention/backends/mla/rocm_aiter_mla.py @@ -14,6 +14,8 @@ from vllm.v1.attention.backends.mla.common import (MLACommonBackend, MLACommonImpl, MLACommonMetadata, MLACommonMetadataBuilder) +from vllm.v1.kv_cache_interface import AttentionSpec +from vllm.v1.worker.block_table import BlockTable # yapf: enable @@ -59,8 +61,9 @@ class AiterMLAMetadata(MLACommonMetadata[AiterMLADecodeMetadata]): class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]): - def __init__(self, runner): - super().__init__(runner) + def __init__(self, runner, kv_cache_spec: AttentionSpec, + block_table: BlockTable): + super().__init__(runner, kv_cache_spec, block_table) max_model_len = self.runner.model_config.max_model_len assert max_model_len 
== 32768,\ "AITER MLA requires max_model_len=32768" diff --git a/vllm/v1/worker/block_table.py b/vllm/v1/worker/block_table.py index 7d4082b73992..581d3d9bd11b 100644 --- a/vllm/v1/worker/block_table.py +++ b/vllm/v1/worker/block_table.py @@ -14,11 +14,13 @@ class BlockTable: self, max_num_reqs: int, max_num_blocks_per_req: int, + max_num_batched_tokens: int, pin_memory: bool, device: torch.device, ): self.max_num_reqs = max_num_reqs self.max_num_blocks_per_req = max_num_blocks_per_req + self.max_num_batched_tokens = max_num_batched_tokens self.pin_memory = pin_memory self.device = device @@ -36,6 +38,15 @@ class BlockTable: self.block_table_np = self.block_table_cpu.numpy() self.num_blocks_per_row = np.zeros(max_num_reqs, dtype=np.int32) + self.slot_mapping_cpu = torch.zeros(self.max_num_batched_tokens, + dtype=torch.int64, + device="cpu", + pin_memory=self.pin_memory) + self.slot_mapping_np = self.slot_mapping_cpu.numpy() + self.slot_mapping = torch.zeros(self.max_num_batched_tokens, + dtype=torch.int64, + device=self.device) + def append_row( self, block_ids: list[int], diff --git a/vllm/v1/worker/gpu_input_batch.py b/vllm/v1/worker/gpu_input_batch.py index c00424dfea73..871654fca366 100644 --- a/vllm/v1/worker/gpu_input_batch.py +++ b/vllm/v1/worker/gpu_input_batch.py @@ -59,6 +59,7 @@ class InputBatch: max_num_reqs: int, max_model_len: int, max_num_blocks_per_req: int, + max_num_batched_tokens: int, device: torch.device, pin_memory: bool, vocab_size: int, @@ -66,6 +67,7 @@ class InputBatch: self.max_num_reqs = max_num_reqs self.max_model_len = max_model_len self.max_num_blocks_per_req = max_num_blocks_per_req + self.max_num_batched_tokens = max_num_batched_tokens self.device = device self.pin_memory = pin_memory self.vocab_size = vocab_size @@ -100,6 +102,7 @@ class InputBatch: self.block_table = BlockTable( max_num_reqs=max_num_reqs, max_num_blocks_per_req=max_num_blocks_per_req, + max_num_batched_tokens=max_num_batched_tokens, pin_memory=pin_memory, device=device, ) diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index bd8c87fd9efc..fdb1339cddca 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -150,8 +150,6 @@ class GPUModelRunner(LoRAModelRunnerMixin): f"FA3. 
Current attention backend is {attn_backend_name}, " f"FlashAttention version is {flash_attn_version}.") - self.attn_metadata_builder = self.attn_backend.get_builder_cls()( - weakref.proxy(self)) self.cascade_attn_enabled = not self.model_config.disable_cascade_attn # Multi-modal data support @@ -174,6 +172,7 @@ class GPUModelRunner(LoRAModelRunnerMixin): # Initialize in initialize_kv_cache self.kv_caches: list[torch.Tensor] = [] # self.kv_cache_config: KVCacheConfig + # self.attn_metadata_builder: type[AttentionMetadataBuilder] # req_id -> (input_id -> encoder_output) self.encoder_cache: dict[str, dict[int, torch.Tensor]] = {} @@ -203,6 +202,7 @@ class GPUModelRunner(LoRAModelRunnerMixin): max_num_reqs=self.max_num_reqs, max_model_len=self.max_model_len, max_num_blocks_per_req=self.max_num_blocks_per_req, + max_num_batched_tokens=self.max_num_tokens, device=self.device, pin_memory=self.pin_memory, vocab_size=model_config.get_vocab_size(), @@ -291,11 +291,6 @@ class GPUModelRunner(LoRAModelRunnerMixin): device="cpu", pin_memory=self.pin_memory) self.positions_np = self.positions_cpu.numpy() - self.slot_mapping_cpu = torch.zeros(self.max_num_tokens, - dtype=torch.int64, - device="cpu", - pin_memory=self.pin_memory) - self.slot_mapping_np = self.slot_mapping_cpu.numpy() self.query_start_loc_cpu = torch.zeros(self.max_num_reqs + 1, dtype=torch.int32, device="cpu", @@ -586,7 +581,8 @@ class GPUModelRunner(LoRAModelRunnerMixin): block_offsets = positions_np % self.block_size np.add(block_numbers * self.block_size, block_offsets, - out=self.slot_mapping_np[:total_num_scheduled_tokens]) + out=self.input_batch.block_table. + slot_mapping_np[:total_num_scheduled_tokens]) # Prepare the attention metadata. self.query_start_loc_np[0] = 0 @@ -614,12 +610,8 @@ class GPUModelRunner(LoRAModelRunnerMixin): self.query_start_loc_cpu[:num_reqs + 1], non_blocking=True) self.seq_lens[:num_reqs].copy_(self.seq_lens_cpu[:num_reqs], non_blocking=True) - self.slot_mapping[:total_num_scheduled_tokens].copy_( - self.slot_mapping_cpu[:total_num_scheduled_tokens], - non_blocking=True) # Fill unused with -1. 
Needed for reshape_and_cache - self.slot_mapping[total_num_scheduled_tokens:].fill_(-1) self.seq_lens[num_reqs:].fill_(0) self.query_start_loc[num_reqs + 1:].fill_(-1) @@ -1821,6 +1813,11 @@ class GPUModelRunner(LoRAModelRunnerMixin): self.vllm_config.compilation_config.static_forward_context, self.kv_caches) + self.attn_metadata_builder = self.attn_backend.get_builder_cls()( + weakref.proxy(self), + kv_cache_config.kv_cache_groups[0].kv_cache_spec, + self.input_batch.block_table) + def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]: """ Generates the KVCacheSpec by parsing the kv cache format from each diff --git a/vllm/v1/worker/tpu_model_runner.py b/vllm/v1/worker/tpu_model_runner.py index be059c30435c..983f8707a245 100644 --- a/vllm/v1/worker/tpu_model_runner.py +++ b/vllm/v1/worker/tpu_model_runner.py @@ -179,6 +179,7 @@ class TPUModelRunner(LoRAModelRunnerMixin): max_num_reqs=self.max_num_reqs, max_model_len=self.max_model_len, max_num_blocks_per_req=self.max_num_blocks_per_req, + max_num_batched_tokens=self.max_num_tokens, device=self.device, pin_memory=self.pin_memory, vocab_size=self.vocab_size, @@ -197,10 +198,6 @@ class TPUModelRunner(LoRAModelRunnerMixin): device="cpu") self.positions_np = self.positions_cpu.numpy() - self.slot_mapping_cpu = torch.zeros(self.max_num_tokens, - dtype=torch.int64, - device="cpu") - self.slot_mapping_np = self.slot_mapping_cpu.numpy() self.block_table_cpu = torch.zeros( (self.max_num_reqs, self.max_num_blocks_per_req), dtype=self.input_batch.block_table.get_cpu_tensor().dtype, @@ -533,7 +530,8 @@ class TPUModelRunner(LoRAModelRunnerMixin): block_offsets = positions_np % self.block_size np.add(block_numbers * self.block_size, block_offsets, - out=self.slot_mapping_np[:total_num_scheduled_tokens]) + out=self.input_batch.block_table. + slot_mapping_cpu[:total_num_scheduled_tokens]) # Prepare the attention metadata. self.query_start_loc_np[0] = 0 @@ -557,10 +555,12 @@ class TPUModelRunner(LoRAModelRunnerMixin): self.position_ids = self.positions_cpu[: padded_total_num_scheduled_tokens].to( self.device) - self.slot_mapping_cpu[total_num_scheduled_tokens:] = _PAD_SLOT_ID - slot_mapping = self.slot_mapping_cpu[: - padded_total_num_scheduled_tokens].to( - self.device) + self.input_batch.block_table.slot_mapping_cpu[ + total_num_scheduled_tokens:] = _PAD_SLOT_ID + slot_mapping = ( + self.input_batch.block_table. + slot_mapping_cpu[:padded_total_num_scheduled_tokens].to( + self.device)) block_tables = self.block_table_cpu[:self.max_num_reqs] block_tables[:num_reqs, :self.max_num_blocks_per_req] = ( self.input_batch.block_table.get_cpu_tensor()[:num_reqs])
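
A minimal sketch (not part of the patch) of how the pieces fit together after this change, based on the test helper and the initialize_kv_cache() hunk above: the KV-cache geometry (block_size, num_kv_heads, head_size, dtype) now travels through an AttentionSpec, the slot-mapping buffers live on the per-batch BlockTable sized by max_num_batched_tokens, and attention metadata builders are constructed with (runner, kv_cache_spec, block_table). The concrete spec values below are illustrative only; in the real runner the spec comes from kv_cache_config.kv_cache_groups[0].kv_cache_spec.

import weakref

import torch

from vllm.v1.kv_cache_interface import FullAttentionSpec


def wire_up_metadata_builder(runner):
    # Illustrative geometry; in GPUModelRunner.initialize_kv_cache() the spec
    # is taken from kv_cache_config.kv_cache_groups[0].kv_cache_spec.
    kv_cache_spec = FullAttentionSpec(block_size=16,
                                      num_kv_heads=1,
                                      head_size=64,
                                      dtype=torch.float16,
                                      use_mla=False)
    # The BlockTable now owns the slot-mapping buffers (slot_mapping_cpu,
    # slot_mapping_np, slot_mapping), sized by max_num_batched_tokens, so
    # builders read them from here instead of from the runner.
    block_table = runner.input_batch.block_table
    # New builder signature: (runner, kv_cache_spec, block_table).
    return runner.attn_backend.get_builder_cls()(
        weakref.proxy(runner), kv_cache_spec, block_table)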