From 61bb55f3d54e58a7888730018c0a727e04f2f2c3 Mon Sep 17 00:00:00 2001 From: Alexander Matveev Date: Tue, 28 Jan 2025 20:55:45 +0000 Subject: [PATCH] Chunked prompt works! Signed-off-by: Alexander Matveev --- vllm/v1/core/scheduler.py | 7 +- vllm/v1/worker/gpu_model_runner.py | 165 +-------- vllm/v1/worker/model_runner_base.py | 167 ++++++++- vllm/v1/worker/tpu_model_runner.py | 533 ++++++++++++---------------- 4 files changed, 406 insertions(+), 466 deletions(-) diff --git a/vllm/v1/core/scheduler.py b/vllm/v1/core/scheduler.py index abd5285de6b16..20f54753b439e 100644 --- a/vllm/v1/core/scheduler.py +++ b/vllm/v1/core/scheduler.py @@ -213,11 +213,12 @@ class Scheduler: num_new_tokens = self.block_size computed_blocks.pop() + # TODO: Remove # If chunked prefill is not enabled, then breakout of the loop # when above budget. - if (not self.scheduler_config.chunked_prefill_enabled - and num_new_tokens > token_budget): - break + # if (not self.scheduler_config.chunked_prefill_enabled + # and num_new_tokens > token_budget): + # break num_new_tokens = min(num_new_tokens, token_budget) assert num_new_tokens > 0 diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 1f5b2f67079ca..6899caedafe1a 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -9,23 +9,18 @@ import torch.distributed from vllm.config import CompilationLevel, VllmConfig from vllm.distributed.parallel_state import graph_capture from vllm.forward_context import set_forward_context -from vllm.inputs import INPUT_REGISTRY from vllm.logger import init_logger from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding from vllm.model_executor.model_loader import get_model from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs from vllm.multimodal.utils import group_mm_inputs_by_modality -from vllm.sampling_params import SamplingType from vllm.utils import DeviceMemoryProfiler, cdiv from vllm.v1.attention.backends.flash_attn import (FlashAttentionBackend, FlashAttentionMetadata) -from vllm.v1.core.encoder_cache_manager import compute_encoder_budget -from vllm.v1.engine.mm_input_mapper import MMInputMapperClient from vllm.v1.kv_cache_interface import FullAttentionSpec, KVCacheConfig from vllm.v1.outputs import ModelRunnerOutput from vllm.v1.sample.metadata import SamplingMetadata from vllm.v1.utils import bind_kv_cache -from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch from vllm.v1.worker.model_runner_base import ExecutionMode, ModelRunnerBase if TYPE_CHECKING: @@ -43,41 +38,9 @@ class GPUModelRunner(ModelRunnerBase): ): super().__init__(vllm_config, device) - # Persistent batch. - self.input_batch = InputBatch( - max_num_reqs=self.max_num_reqs, - max_model_len=self.max_model_len, - max_num_blocks_per_req=self.max_num_blocks_per_req, - device=self.device, - pin_memory=self.pin_memory, - vocab_size=self.model_config.get_vocab_size(), - ) - - # Request states. - self.requests: Dict[str, CachedRequestState] = {} - # KV caches for forward pass self.kv_caches: List[torch.Tensor] = [] - # Multi-modal data support - self.input_registry = INPUT_REGISTRY - self.mm_registry = MULTIMODAL_REGISTRY - - # NOTE: Initialized input mapper is only used for processing dummy - # multimodal data into multimodal kwargs for GPU memory profiling. 
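The scheduler.py hunk above comments out the non-chunked-prefill breakout, so a prompt that exceeds the remaining token budget is clamped to that budget and continued on a later scheduler iteration instead of being skipped. A minimal sketch of the resulting budgeting rule, using a hypothetical helper name and made-up numbers rather than the scheduler's real loop:

    def schedule_prompt_tokens(num_prompt_tokens: int,
                               num_computed_tokens: int,
                               token_budget: int) -> int:
        """How many prompt tokens to schedule this step (chunked prompts)."""
        num_new_tokens = num_prompt_tokens - num_computed_tokens
        # Never break out: just clamp to whatever budget is left.
        num_new_tokens = min(num_new_tokens, token_budget)
        assert num_new_tokens > 0
        return num_new_tokens

    # Example: a 1000-token prompt against a 512-token budget is scheduled as
    # 512 tokens now and the remaining 488 on the next iteration.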
- self.mm_input_mapper_profiling = MMInputMapperClient(self.model_config) - self.mm_input_mapper_profiling.use_cache = False - - encoder_compute_budget, encoder_cache_size = compute_encoder_budget( - model_config=self.model_config, - scheduler_config=self.scheduler_config, - ) - self.max_num_encoder_input_tokens = encoder_compute_budget - self.encoder_cache_size = encoder_cache_size - - # req_id -> (input_id -> encoder_output) - self.encoder_cache: Dict[str, Dict[int, torch.Tensor]] = {} - self.use_cuda_graph = (self.vllm_config.compilation_config.level == CompilationLevel.PIECEWISE and not self.model_config.enforce_eager) @@ -160,132 +123,6 @@ class GPUModelRunner(ModelRunnerBase): pin_memory=self.pin_memory) self.seq_lens_np = self.seq_lens_cpu.numpy() - def _update_states(self, scheduler_output: "SchedulerOutput") -> None: - # Remove stopped requests from the cached states. - # Keep the states of the pre-empted requests. - for req_id in scheduler_output.finished_req_ids: - self.requests.pop(req_id, None) - self.encoder_cache.pop(req_id, None) - - # Free the cached encoder outputs. - for req_id, input_id in scheduler_output.free_encoder_input_ids: - encoder_outputs = self.encoder_cache.get(req_id) - if encoder_outputs is not None: - encoder_outputs.pop(input_id, None) - if not encoder_outputs: - self.encoder_cache.pop(req_id, None) - - # Remove the requests from the persistent batch. - stopped_req_ids = set().union( - scheduler_output.preempted_req_ids, - scheduler_output.finished_req_ids, - ) - removed_req_indices: List[int] = [] - for req_id in stopped_req_ids: - req_index = self.input_batch.remove_request(req_id) - if req_index is not None: - removed_req_indices.append(req_index) - - # Update the states of the running requests. - for req_data in scheduler_output.scheduled_running_reqs: - req_id = req_data.req_id - req_state = self.requests[req_id] - req_index = self.input_batch.req_id_to_index[req_id] - - # Update the num_computed_tokens. - req_state.num_computed_tokens = req_data.num_computed_tokens - self.input_batch.num_computed_tokens_cpu[req_index] = ( - req_data.num_computed_tokens) - - # Update the block table. - num_new_blocks = len(req_data.new_block_ids) - if num_new_blocks == 0: - continue - start_index = len(req_state.block_ids) - req_state.block_ids.extend(req_data.new_block_ids) - self.input_batch.block_table.append_row(req_index, start_index, - req_data.new_block_ids) - - req_ids_to_add: List[str] = [] - # Add new requests to the cached states. 
- for new_req_data in scheduler_output.scheduled_new_reqs: - req_id = new_req_data.req_id - sampling_params = new_req_data.sampling_params - if sampling_params.sampling_type == SamplingType.RANDOM_SEED: - generator = torch.Generator(device=self.device) - generator.manual_seed(sampling_params.seed) - else: - generator = None - - self.requests[req_id] = CachedRequestState( - req_id=req_id, - prompt_token_ids=new_req_data.prompt_token_ids, - prompt=new_req_data.prompt, - mm_inputs=new_req_data.mm_inputs, - mm_positions=new_req_data.mm_positions, - sampling_params=sampling_params, - generator=generator, - block_ids=new_req_data.block_ids, - num_computed_tokens=new_req_data.num_computed_tokens, - output_token_ids=[], - ) - - # Only relevant for models using M-RoPE (e.g, Qwen2-VL) - if self.model_config.uses_mrope: - image_grid_thw = [] - video_grid_thw = [] - for mm_input in self.requests[req_id].mm_inputs: - if mm_input.get("image_grid_thw") is not None: - image_grid_thw.extend( - mm_input["image_grid_thw"].tolist()) - if mm_input.get("video_grid_thw") is not None: - video_grid_thw.extend( - mm_input["video_grid_thw"].tolist()) - - hf_config = self.model_config.hf_config - - self.requests[req_id].mrope_positions, \ - self.requests[req_id].mrope_position_delta = \ - MRotaryEmbedding.get_input_positions_tensor( - self.requests[req_id].prompt_token_ids, - image_grid_thw=image_grid_thw, - video_grid_thw=video_grid_thw, - image_token_id=hf_config.image_token_id, - video_token_id=hf_config.video_token_id, - vision_start_token_id=hf_config.vision_start_token_id, - vision_end_token_id=hf_config.vision_end_token_id, - spatial_merge_size=hf_config.vision_config. - spatial_merge_size, - ) - - req_ids_to_add.append(req_id) - - # Update the cached states of the resumed requests. - for res_req_data in scheduler_output.scheduled_resumed_reqs: - req_id = res_req_data.req_id - req_state = self.requests[req_id] - - req_state.block_ids = res_req_data.block_ids - req_state.num_computed_tokens = res_req_data.num_computed_tokens - req_ids_to_add.append(req_id) - - # Add the new or resumed requests to the persistent batch. - # The smaller empty indices are filled first. - removed_req_indices = sorted(removed_req_indices, reverse=True) - for req_id in req_ids_to_add: - req_state = self.requests[req_id] - if removed_req_indices: - # Fill the empty index. - req_index = removed_req_indices.pop() - else: - # Append to the end. - req_index = None - self.input_batch.add_request(req_state, req_index) - - # Condense the batched states if there are empty indices. - if removed_req_indices: - self.input_batch.condense(removed_req_indices) - def _prepare_inputs(self, scheduler_output: "SchedulerOutput"): total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens assert total_num_scheduled_tokens > 0 @@ -665,7 +502,7 @@ class GPUModelRunner(ModelRunnerBase): ) -> ModelRunnerOutput: assert self.model is not None - self._update_states(scheduler_output) + self.update_states(scheduler_output) if self.is_multimodal_model: # Run the multimodal encoder if any. 
diff --git a/vllm/v1/worker/model_runner_base.py b/vllm/v1/worker/model_runner_base.py index f45b24def7843..4b2ef2aea2b85 100644 --- a/vllm/v1/worker/model_runner_base.py +++ b/vllm/v1/worker/model_runner_base.py @@ -1,5 +1,5 @@ import enum -from typing import TYPE_CHECKING, Optional +from typing import TYPE_CHECKING, Dict, List, Optional import torch import torch.distributed @@ -8,11 +8,18 @@ import torch.nn as nn from vllm.attention.backends.abstract import AttentionType from vllm.attention.layer import Attention from vllm.config import VllmConfig +from vllm.inputs import INPUT_REGISTRY from vllm.logger import init_logger +from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding +from vllm.multimodal import MULTIMODAL_REGISTRY +from vllm.sampling_params import SamplingType from vllm.utils import LayerBlockType, cdiv, is_pin_memory_available +from vllm.v1.core.encoder_cache_manager import compute_encoder_budget +from vllm.v1.engine.mm_input_mapper import MMInputMapperClient from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig, KVCacheSpec) from vllm.v1.outputs import ModelRunnerOutput +from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch if TYPE_CHECKING: from vllm.v1.core.scheduler import SchedulerOutput @@ -75,6 +82,164 @@ class ModelRunnerBase: self.model: Optional[nn.Module] = None + # Persistent batch. + self.input_batch = InputBatch( + max_num_reqs=self.max_num_reqs, + max_model_len=self.max_model_len, + max_num_blocks_per_req=self.max_num_blocks_per_req, + device=self.device, + pin_memory=self.pin_memory, + vocab_size=self.model_config.get_vocab_size(), + ) + + # Request states. + self.requests: Dict[str, CachedRequestState] = {} + + # Multi-modal data support + self.input_registry = INPUT_REGISTRY + self.mm_registry = MULTIMODAL_REGISTRY + + # NOTE: Initialized input mapper is only used for processing dummy + # multimodal data into multimodal kwargs for GPU memory profiling. + self.mm_input_mapper_profiling = MMInputMapperClient(self.model_config) + self.mm_input_mapper_profiling.use_cache = False + + encoder_compute_budget, encoder_cache_size = compute_encoder_budget( + model_config=self.model_config, + scheduler_config=self.scheduler_config, + ) + self.max_num_encoder_input_tokens = encoder_compute_budget + self.encoder_cache_size = encoder_cache_size + + # req_id -> (input_id -> encoder_output) + self.encoder_cache: Dict[str, Dict[int, torch.Tensor]] = {} + + def update_states(self, scheduler_output: "SchedulerOutput") -> None: + # Remove stopped requests from the cached states. + # Keep the states of the pre-empted requests. + for req_id in scheduler_output.finished_req_ids: + self.requests.pop(req_id, None) + self.encoder_cache.pop(req_id, None) + + # Free the cached encoder outputs. + for req_id, input_id in scheduler_output.free_encoder_input_ids: + encoder_outputs = self.encoder_cache.get(req_id) + if encoder_outputs is not None: + encoder_outputs.pop(input_id, None) + if not encoder_outputs: + self.encoder_cache.pop(req_id, None) + + # Remove the requests from the persistent batch. + stopped_req_ids = set().union( + scheduler_output.preempted_req_ids, + scheduler_output.finished_req_ids, + ) + removed_req_indices: List[int] = [] + for req_id in stopped_req_ids: + req_index = self.input_batch.remove_request(req_id) + if req_index is not None: + removed_req_indices.append(req_index) + + # Update the states of the running requests. 
+ for req_data in scheduler_output.scheduled_running_reqs: + req_id = req_data.req_id + req_state = self.requests[req_id] + req_index = self.input_batch.req_id_to_index[req_id] + + # Update the num_computed_tokens. + req_state.num_computed_tokens = req_data.num_computed_tokens + self.input_batch.num_computed_tokens_cpu[req_index] = ( + req_data.num_computed_tokens) + + # Update the block table. + num_new_blocks = len(req_data.new_block_ids) + if num_new_blocks == 0: + continue + start_index = len(req_state.block_ids) + req_state.block_ids.extend(req_data.new_block_ids) + self.input_batch.block_table.append_row(req_index, start_index, + req_data.new_block_ids) + + req_ids_to_add: List[str] = [] + # Add new requests to the cached states. + for new_req_data in scheduler_output.scheduled_new_reqs: + req_id = new_req_data.req_id + sampling_params = new_req_data.sampling_params + if sampling_params.sampling_type == SamplingType.RANDOM_SEED: + generator = torch.Generator(device=self.device) + generator.manual_seed(sampling_params.seed) + else: + generator = None + + self.requests[req_id] = CachedRequestState( + req_id=req_id, + prompt_token_ids=new_req_data.prompt_token_ids, + prompt=new_req_data.prompt, + mm_inputs=new_req_data.mm_inputs, + mm_positions=new_req_data.mm_positions, + sampling_params=sampling_params, + generator=generator, + block_ids=new_req_data.block_ids, + num_computed_tokens=new_req_data.num_computed_tokens, + output_token_ids=[], + ) + + # Only relevant for models using M-RoPE (e.g, Qwen2-VL) + if self.model_config.uses_mrope: + image_grid_thw = [] + video_grid_thw = [] + for mm_input in self.requests[req_id].mm_inputs: + if mm_input.get("image_grid_thw") is not None: + image_grid_thw.extend( + mm_input["image_grid_thw"].tolist()) + if mm_input.get("video_grid_thw") is not None: + video_grid_thw.extend( + mm_input["video_grid_thw"].tolist()) + + hf_config = self.model_config.hf_config + + self.requests[req_id].mrope_positions, \ + self.requests[req_id].mrope_position_delta = \ + MRotaryEmbedding.get_input_positions_tensor( + self.requests[req_id].prompt_token_ids, + image_grid_thw=image_grid_thw, + video_grid_thw=video_grid_thw, + image_token_id=hf_config.image_token_id, + video_token_id=hf_config.video_token_id, + vision_start_token_id=hf_config.vision_start_token_id, + vision_end_token_id=hf_config.vision_end_token_id, + spatial_merge_size=hf_config.vision_config. + spatial_merge_size, + ) + + req_ids_to_add.append(req_id) + + # Update the cached states of the resumed requests. + for res_req_data in scheduler_output.scheduled_resumed_reqs: + req_id = res_req_data.req_id + req_state = self.requests[req_id] + + req_state.block_ids = res_req_data.block_ids + req_state.num_computed_tokens = res_req_data.num_computed_tokens + req_ids_to_add.append(req_id) + + # Add the new or resumed requests to the persistent batch. + # The smaller empty indices are filled first. + removed_req_indices = sorted(removed_req_indices, reverse=True) + for req_id in req_ids_to_add: + req_state = self.requests[req_id] + if removed_req_indices: + # Fill the empty index. + req_index = removed_req_indices.pop() + else: + # Append to the end. + req_index = None + self.input_batch.add_request(req_state, req_index) + + # Condense the batched states if there are empty indices. 
+ if removed_req_indices: + self.input_batch.condense(removed_req_indices) + def get_model(self) -> nn.Module: assert self.model is not None return self.model diff --git a/vllm/v1/worker/tpu_model_runner.py b/vllm/v1/worker/tpu_model_runner.py index 7be9bd27aeb4f..cfe4792a95d19 100644 --- a/vllm/v1/worker/tpu_model_runner.py +++ b/vllm/v1/worker/tpu_model_runner.py @@ -15,7 +15,6 @@ from vllm.config import VllmConfig from vllm.forward_context import set_forward_context from vllm.logger import init_logger from vllm.model_executor.model_loader import get_model -from vllm.sampling_params import SamplingType from vllm.v1.attention.backends.pallas import (PallasAttentionBackend, PallasMetadata) from vllm.v1.kv_cache_interface import FullAttentionSpec, KVCacheConfig @@ -40,25 +39,24 @@ _MAX_NUM_SAMPLES = 128 @dataclass -class PrefillInputData: +class PromptInputData: - request_ids: List + req_ids: List prompt_lens: List - token_ids: List - position_ids: List + input_tokens: List + input_positions: List attn_metadata: List def zipped(self): - return zip(self.request_ids, self.prompt_lens, self.token_ids, - self.position_ids, self.attn_metadata) + return zip(self.req_ids, self.prompt_lens, self.input_tokens, + self.input_positions, self.attn_metadata) @dataclass class DecodeInputData: - - num_decodes: int - token_ids: Optional[torch.Tensor] = None - position_ids: Optional[torch.Tensor] = None + req_ids: List + input_tokens: Optional[torch.Tensor] = None + input_positions: Optional[torch.Tensor] = None attn_metadata: Optional[PallasMetadata] = None @@ -88,158 +86,105 @@ class TPUModelRunner(ModelRunnerBase): self.kv_caches: List[Tuple[torch.Tensor, torch.Tensor]] = [] # Used to initialize positions for the individual prefills - self.prefill_positions = torch.tensor(range(self.max_model_len), - device="cpu", - dtype=torch.int32).reshape( - 1, -1) + self.prefill_input_positions = torch.tensor(range(self.max_model_len), + device="cpu", + dtype=torch.int32).reshape( + 1, -1) - # Used to indicate how many prefills there are for each scheduler - # iteration - self.num_new_reqs: int = 0 - - def _update_states(self, scheduler_output: "SchedulerOutput") -> None: - # Remove stopped requests from the cached states. - # Keep the states of the pre-empted requests. - for req_id in scheduler_output.finished_req_ids: - self.requests.pop(req_id, None) - - # Remove the requests from the persistent batch. - stopped_req_ids = set().union( - scheduler_output.preempted_req_ids, - scheduler_output.finished_req_ids, - ) - removed_req_indices: List[int] = [] - for req_id in stopped_req_ids: - req_index = self.input_batch.remove_request(req_id) - if req_index is not None: - removed_req_indices.append(req_index) - - # Update the states of the running requests. - for req_data in scheduler_output.scheduled_running_reqs: - req_id = req_data.req_id - req_state = self.requests[req_id] - req_index = self.input_batch.req_id_to_index[req_id] - - # Update the num_computed_tokens. - req_state.num_computed_tokens = req_data.num_computed_tokens - self.input_batch.num_computed_tokens_cpu[req_index] = ( - req_data.num_computed_tokens) - - # Update the block table. - num_new_blocks = len(req_data.new_block_ids) - if num_new_blocks == 0: - continue - start_index = len(req_state.block_ids) - req_state.block_ids.extend(req_data.new_block_ids) - self.input_batch.block_table.append_row(req_index, start_index, - req_data.new_block_ids) - - req_ids_to_add: List[str] = [] - # Add new requests to the cached states. 
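The request bookkeeping removed from gpu_model_runner.py above now lives in ModelRunnerBase (model_runner_base.py hunk), and the TPU-specific copy being removed here becomes redundant as well. A hypothetical skeleton of the resulting split, with method bodies elided:

    class ModelRunnerBase:
        def update_states(self, scheduler_output) -> None:
            """Shared bookkeeping: drop finished/preempted requests, add new
            and resumed ones, update num_computed_tokens and block tables,
            then condense the persistent batch."""

    class GPUModelRunner(ModelRunnerBase):
        def execute_model(self, scheduler_output):
            self.update_states(scheduler_output)  # was self._update_states
            # ... GPU-specific input preparation and forward pass ...

    class TPUModelRunner(ModelRunnerBase):
        def execute_model(self, scheduler_output):
            self.update_states(scheduler_output)
            # ... separate prompt and decode phases, prepared below ...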
- for new_req_data in scheduler_output.scheduled_new_reqs: - req_id = new_req_data.req_id - sampling_params = new_req_data.sampling_params - if sampling_params.sampling_type == SamplingType.RANDOM_SEED: - generator = torch.Generator(device=self.device) - generator.manual_seed(sampling_params.seed) - else: - generator = None - - self.requests[req_id] = CachedRequestState( - req_id=req_id, - prompt_token_ids=new_req_data.prompt_token_ids, - prompt=new_req_data.prompt, - mm_inputs=new_req_data.mm_inputs, - mm_positions=new_req_data.mm_positions, - sampling_params=sampling_params, - generator=generator, - block_ids=new_req_data.block_ids, - num_computed_tokens=new_req_data.num_computed_tokens, - output_token_ids=[], - ) - req_ids_to_add.append(req_id) - - # Update the cached states of the resumed requests. - for res_req_data in scheduler_output.scheduled_resumed_reqs: - req_id = res_req_data.req_id - req_state = self.requests[req_id] - - req_state.block_ids = res_req_data.block_ids - req_state.num_computed_tokens = res_req_data.num_computed_tokens - req_ids_to_add.append(req_id) - - # For TPU, we keep all of the decode requests before the - # prefill requests in the batch sequence. - # 1. First condense, so all decodes move to start - # 2. Then add new prefills to the end of the batch - removed_req_indices = sorted(removed_req_indices, reverse=True) - if removed_req_indices: - self.input_batch.condense(removed_req_indices) - - for req_id in req_ids_to_add: - req_state = self.requests[req_id] - self.input_batch.add_request(req_state, None) # Append last - - self.num_new_reqs = len(req_ids_to_add) - - def _prepare_prefill_inputs( + def _prepare_prompt_inputs( self, - num_scheduled_tokens: List[int], - ) -> PrefillInputData: - # Each prefill run separately with shape [1, padded_prompt_len]. - # So we create lists that will be used in execute_model(). - - prefill_request_ids = [] - prefill_prompt_lens = [] - prefill_token_ids = [] - prefill_position_ids = [] - prefill_attn_metadata = [] - - # DECODES are the first num_decodes REQUESTS. - # PREFILLS are the next num_reqs - num_decodes REQUESTS. + scheduler_output: "SchedulerOutput", + ) -> PromptInputData: + total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens + assert total_num_scheduled_tokens > 0 num_reqs = self.input_batch.num_reqs - num_decodes = num_reqs - self.num_new_reqs - for idx in range(num_decodes, num_reqs): - req_id = self.input_batch.req_ids[idx] - prefill_request_ids.append(req_id) + assert num_reqs > 0 - prompt_len = num_scheduled_tokens[idx] - prefill_prompt_lens.append(prompt_len) + # OPTIMIZATION: Start copying the block table first. + # This way, we can overlap the copy with the following CPU operations. + self.input_batch.block_table.commit(num_reqs) - # STATIC SHAPE: prefills are padded to the next power of 2. 
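Prompt chunks keep static TPU shapes by padding each chunk to the next power of two via _get_padded_prefill_len (used in the new code below). A sketch of what such a helper typically looks like; the minimum length of 16 is an assumption, and the actual implementation in tpu_model_runner.py may differ:

    def _get_padded_prefill_len(x: int) -> int:
        # Round up to the next power of two so each prompt chunk hits one of a
        # small, fixed set of shapes and XLA does not recompile per request.
        if x <= 16:  # assumed floor for very short chunks
            return 16
        return 1 << (x - 1).bit_length()

    # Example: prompt_len = 300 -> padded_prompt_len = 512. The padded tail is
    # zeroed in input_tokens/input_positions and its slot_mapping entries are
    # set to _PAD_SLOT_ID so it never reaches the KV cache.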
+ req_ids = [] + prompt_lens = [] + input_tokens_list = [] + input_positions_list = [] + attn_metadata_list = [] + for req_id in self.input_batch.req_ids[:num_reqs]: + assert req_id is not None + req_index = self.input_batch.req_id_to_index[req_id] + req_state = self.requests[req_id] + + num_scheduled_tokens = scheduler_output.num_scheduled_tokens[ + req_id] + num_computed_tokens = req_state.num_computed_tokens + num_prompt_tokens = len(req_state.prompt_token_ids) + + # Detect whether this is a prompt (can be full or chunked) + if num_computed_tokens >= num_prompt_tokens: + # This is a decode => Skip + continue + + # This is a prompt + req_ids.append(req_id) + + # Prompt len + prompt_len = num_scheduled_tokens + prompt_lens.append(prompt_len) padded_prompt_len = _get_padded_prefill_len(prompt_len) assert padded_prompt_len <= self.max_model_len - # TOKEN_IDS. - token_ids = torch.from_numpy(self.input_batch.token_ids_cpu[ - idx, :padded_prompt_len].reshape(1, -1)) - token_ids[:, prompt_len:] = 0 - prefill_token_ids.append(token_ids.to(self.device)) + # Seq len + seq_len = num_computed_tokens + prompt_len + padded_seq_len = num_computed_tokens + padded_prompt_len - # POSITIONS. - positions = self.prefill_positions[:, :padded_prompt_len].clone() - positions[:, prompt_len:] = 0 - prefill_position_ids.append(positions.to(self.device)) + # Input tokens + input_tokens = torch.from_numpy(self.input_batch.token_ids_cpu[ + req_index, num_computed_tokens:padded_seq_len].reshape(1, -1)) + input_tokens[:, prompt_len:] = 0 + input_tokens_list.append(input_tokens.to(self.device)) - # SLOT_MAPPING. - # The "slot" is the "physical index" of a token in the KV cache. - # Look up the block_idx in the block table (logical<>physical map) - # to compute this. + # Input positions + input_positions = self.prefill_input_positions[:, + num_computed_tokens: + padded_seq_len].clone( + ) + input_positions[:, prompt_len:] = 0 + input_positions_list.append(input_positions.to(self.device)) + + # Slot mapping block_table_cpu_tensor = \ self.input_batch.block_table.get_cpu_tensor() - - block_numbers = block_table_cpu_tensor[idx, positions // + block_numbers = block_table_cpu_tensor[req_index, + input_positions // self.block_size].reshape( 1, -1) - block_offsets = positions % self.block_size + block_offsets = input_positions % self.block_size slot_mapping = block_numbers * self.block_size + block_offsets - # Set an out of range value for the padding tokens so that they - # are ignored when inserting into the KV cache. slot_mapping[:, prompt_len:] = _PAD_SLOT_ID slot_mapping = slot_mapping.long() - prefill_attn_metadata.append( + # Block table + block_table = None + if num_computed_tokens > 0: + block_table = self.input_batch.block_table.get_device_tensor() + block_table = block_table[req_index].unsqueeze(0) + + # Context len + context_len = 0 + if num_computed_tokens > 0: + context_len = seq_len + context_lens = torch.tensor([context_len], + dtype=torch.int32, + device="cpu") + + # Effective query len + effective_query_lens = torch.tensor([prompt_len], + dtype=torch.int32, + device="cpu") + + # Attn metadata + attn_metadata_list.append( PallasMetadata( num_prefills=1, num_prefill_tokens=0, # NOTE: This is not used. 
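The chunked-prompt preparation above slices the window [num_computed_tokens, num_computed_tokens + padded_prompt_len) out of the CPU token buffer and maps every position to a physical KV-cache slot through the block table. A self-contained sketch of that slot computation with made-up sizes; _PAD_SLOT_ID here is a stand-in for the module-level constant:

    import torch

    _PAD_SLOT_ID = -1  # stand-in; the runner defines its own out-of-range value

    def chunk_slot_mapping(block_table_row: torch.Tensor,
                           num_computed_tokens: int,
                           prompt_len: int,
                           padded_prompt_len: int,
                           block_size: int) -> torch.Tensor:
        # Positions covered by this chunk, including the static-shape padding.
        positions = torch.arange(num_computed_tokens,
                                 num_computed_tokens + padded_prompt_len,
                                 dtype=torch.int64)
        block_numbers = block_table_row[positions // block_size]
        block_offsets = positions % block_size
        slot_mapping = block_numbers * block_size + block_offsets
        # Padding tokens must not be written into the KV cache.
        slot_mapping[prompt_len:] = _PAD_SLOT_ID
        return slot_mapping

    # E.g. with block_size = 16, the second chunk of a prompt (starting at
    # token 512) writes its first token into slot block_table_row[32] * 16.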
@@ -247,208 +192,200 @@ class TPUModelRunner(ModelRunnerBase): slot_mapping=slot_mapping.to(self.device), multi_modal_placeholder_index_maps=None, enable_kv_scales_calculation=True, - block_tables=None, - context_lens=None, - effective_query_lens=None, + block_tables=block_table, + context_lens=context_lens.to(self.device), + effective_query_lens=effective_query_lens.to(self.device), )) - return PrefillInputData( - request_ids=prefill_request_ids, - prompt_lens=prefill_prompt_lens, - token_ids=prefill_token_ids, - position_ids=prefill_position_ids, - attn_metadata=prefill_attn_metadata, + return PromptInputData( + req_ids=req_ids, + prompt_lens=prompt_lens, + input_tokens=input_tokens_list, + input_positions=input_positions_list, + attn_metadata=attn_metadata_list, ) - def _prepare_decode_inputs(self) -> DecodeInputData: - # Decodes run as one single padded batch with shape [batch, 1] - # - # We need to set _PAD_SLOT_ID for the padding tokens in the - # slot_mapping, such that the attention KV cache insertion - # logic knows to ignore those indices. Otherwise, the - # padding data can be dummy since we have a causal mask. - - # DECODES are the first num_decodes REQUESTS. - # PREFILLS are the next num_reqs - num_decodes REQUESTS. - num_reqs = self.input_batch.num_reqs - num_decodes = num_reqs - self.num_new_reqs - - if num_decodes == 0: - return DecodeInputData(num_decodes=0) - - # PAD FOR STATIC SHAPES. - padded_batch_size = _get_padded_batch_size(num_decodes) - - # POSITIONS. [batch, 1] - # We slice at the end, since we use the positions for gathering. - positions = torch.from_numpy( - self.input_batch.num_computed_tokens_cpu.reshape(-1, 1)) - index = positions.to(torch.int64) - index[num_decodes:] = 0 - positions = positions[:padded_batch_size] - positions[num_decodes:] = 0 - - # TOKEN_IDS. [batch, 1] - token_ids = torch.gather( - input=torch.from_numpy(self.input_batch.token_ids_cpu), - dim=1, - index=index, - )[:padded_batch_size].to(torch.int32) - token_ids[num_decodes:] = 0 - - # SLOT_MAPPING [batch, 1] - # The "slot" is the "physical index" of a token in the KV cache. - # Look up the block_idx in the block table (logical<>physical map) - # to compute this. - block_table_cpu_tensor = self.input_batch.block_table.get_cpu_tensor() - block_number = torch.gather(input=block_table_cpu_tensor, - dim=1, - index=(index // self.block_size)) - block_offsets = index % self.block_size - slot_mapping = block_number * self.block_size + block_offsets - # Set an out of range value for the padding tokens so that they - # are ignored when inserting into the KV cache. - slot_mapping[num_decodes:] = _PAD_SLOT_ID - slot_mapping = slot_mapping[:padded_batch_size] - slot_mapping = slot_mapping.long() - - # BLOCK_TABLE [batch, max_num_blocks_per_req] - block_table = block_table_cpu_tensor[:padded_batch_size] - - # CONTEXT_LENS [batch_size] - context_lens = (positions.reshape(-1) + 1) - context_lens[num_decodes:] = 0 - - # CPU<>TPU sync happens here. 
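The prompt-side attention metadata built above is what lets a later chunk attend to KV-cache entries written by earlier chunks: the block table and context length are only populated once some prompt tokens have already been computed. A small sketch mirroring those assignments, with an assumed 1000-token prompt split into 512-token chunks:

    def chunk_attn_metadata_fields(num_computed_tokens: int, prompt_len: int):
        """Mirror the per-chunk assignments in _prepare_prompt_inputs."""
        seq_len = num_computed_tokens + prompt_len
        use_block_table = num_computed_tokens > 0
        context_len = seq_len if num_computed_tokens > 0 else 0
        effective_query_len = prompt_len
        return use_block_table, context_len, effective_query_len

    # First chunk: no cached context yet.
    assert chunk_attn_metadata_fields(0, 512) == (False, 0, 512)
    # Second chunk: attends to the 512 cached tokens plus its own 488.
    assert chunk_attn_metadata_fields(512, 488) == (True, 1000, 488)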
- return DecodeInputData(num_decodes=num_decodes, - token_ids=token_ids.to(self.device), - position_ids=positions.to(self.device), - attn_metadata=PallasMetadata( - num_prefills=0, - num_prefill_tokens=0, - num_decode_tokens=padded_batch_size, - slot_mapping=slot_mapping.to(self.device), - multi_modal_placeholder_index_maps=None, - enable_kv_scales_calculation=True, - block_tables=block_table.to(self.device), - context_lens=context_lens.to(self.device), - effective_query_lens=None, - )) - - def _prepare_inputs(self, scheduler_output: "SchedulerOutput"): + def _prepare_decode_inputs( + self, + scheduler_output: "SchedulerOutput", + ) -> DecodeInputData: total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens assert total_num_scheduled_tokens > 0 num_reqs = self.input_batch.num_reqs assert num_reqs > 0 - num_decodes = num_reqs - self.num_new_reqs + block_table_cpu_tensor = self.input_batch.block_table.get_cpu_tensor() - # TODO: Resurrect - # OPTIMIZATION: Start copying the block table first. - # This way, we can overlap the copy with the following CPU operations. - # TODO: Verify this works with TPUs - # self.input_batch.block_table.commit(num_reqs) - - # Get the number of scheduled tokens for each request. - # TODO: The Python loop can be slow. Optimize. - num_scheduled_tokens = [] - max_num_scheduled_tokens = 0 - for idx, req_id in enumerate(self.input_batch.req_ids[:num_reqs]): + req_ids = [] + req_indices = [] + input_tokens = [] + input_positions = [] + slot_mapping = [] + context_lens = [] + for req_id in self.input_batch.req_ids[:num_reqs]: assert req_id is not None - num_tokens = scheduler_output.num_scheduled_tokens[req_id] - num_scheduled_tokens.append(num_tokens) - max_num_scheduled_tokens = max(max_num_scheduled_tokens, - num_tokens) + req_index = self.input_batch.req_id_to_index[req_id] + req_state = self.requests[req_id] - # NOTE: Assert that all the decodes are "decodes". 
- if idx < num_decodes: - assert num_tokens == 1 + num_scheduled_tokens = scheduler_output.num_scheduled_tokens[ + req_id] + num_computed_tokens = req_state.num_computed_tokens + num_prompt_tokens = len(req_state.prompt_token_ids) - assert max_num_scheduled_tokens > 0 + # Detect whether this is a decode + if num_computed_tokens < num_prompt_tokens: + # This is a prompt => Skip + continue - return ( - self._prepare_prefill_inputs(num_scheduled_tokens), - self._prepare_decode_inputs(), + # This is a decode + req_ids.append(req_id) + req_indices.append(req_index) + + # Seq len + seq_len = num_computed_tokens + num_scheduled_tokens + + # Sanity check decode + assert num_scheduled_tokens == 1 + assert seq_len == req_state.num_tokens + + # Input token + input_tokens.append([ + self.input_batch.token_ids_cpu[req_index, num_computed_tokens] + ]) + + # Position + input_positions.append([num_computed_tokens]) + + # Slot mapping + block_number = block_table_cpu_tensor[req_index, + num_computed_tokens // + self.block_size] + block_offset = num_computed_tokens % self.block_size + slot_id = block_number * self.block_size + block_offset + slot_mapping.append([slot_id]) + + # Context len + context_lens.append(seq_len) + + # Compute padding + batch_size = len(input_tokens) + padded_batch_size = _get_padded_batch_size(batch_size) + num_padding = padded_batch_size - batch_size + + # Add padding + input_tokens.extend([[0]] * num_padding) + input_positions.extend([[0]] * num_padding) + slot_mapping.extend([[_PAD_SLOT_ID]] * num_padding) + context_lens.extend([0] * num_padding) + req_indices.extend([0] * num_padding) + + # Create tensors + input_tokens_tensor = torch.tensor(input_tokens, + dtype=torch.int32, + device="cpu") + input_positions_tensor = torch.tensor(input_positions, + dtype=torch.int32, + device="cpu") + slot_mapping_tensor = torch.tensor(slot_mapping, + dtype=torch.int64, + device="cpu") + context_lens_tensor = torch.tensor(context_lens, + dtype=torch.int32, + device="cpu") + block_tables_tensor = block_table_cpu_tensor[req_indices] + + # Attn metadata + attn_metadata = PallasMetadata( + num_prefills=0, + num_prefill_tokens=0, + num_decode_tokens=padded_batch_size, + slot_mapping=slot_mapping_tensor.to(self.device), + multi_modal_placeholder_index_maps=None, + enable_kv_scales_calculation=True, + block_tables=block_tables_tensor.to(self.device), + context_lens=context_lens_tensor.to(self.device), + effective_query_lens=None, ) + return DecodeInputData( + req_ids=req_ids, + input_tokens=input_tokens_tensor.to(self.device), + input_positions=input_positions_tensor.to(self.device), + attn_metadata=attn_metadata) + @torch.no_grad() def execute_model( self, scheduler_output: "SchedulerOutput", ) -> ModelRunnerOutput: - self._update_states(scheduler_output) + # Update cached state + self.update_states(scheduler_output) - # Prepare the decoder inputs. - prefill_data, decode_data = self._prepare_inputs(scheduler_output) + # Prepare inputs + prompt_data = self._prepare_prompt_inputs(scheduler_output) + decode_data = self._prepare_decode_inputs(scheduler_output) + # Init num_reqs = self.input_batch.num_reqs + assert num_reqs > 0 + sampled_token_ids_list = [0] * num_reqs - ######################### DECODES ######################### - # Decodes run as one single batch with [padded_batch, 1] - sampled_token_ids_list = [] - if decode_data.num_decodes > 0: - # FORWARD. 
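The decode-side preparation earlier in this hunk pads the batch to a fixed size, filling the extra rows with zeros and _PAD_SLOT_ID so they never touch the KV cache. A sketch of the batch-size padding rule; the breakpoints (floor of 8, multiples of 16) are assumptions and the real _get_padded_batch_size in tpu_model_runner.py may differ:

    def _get_padded_batch_size(batch_size: int) -> int:
        # Keep the decode graph to a handful of compiled batch shapes.
        if batch_size <= 8:
            return 8
        return ((batch_size + 15) // 16) * 16

    # E.g. 3 live decode requests -> padded_batch_size = 8, num_padding = 5:
    # five extra [0] rows in input_tokens/input_positions, five [_PAD_SLOT_ID]
    # rows in slot_mapping, and five 0 entries in context_lens/req_indices.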
+ # Run decodes (a single batch) + if len(decode_data.req_ids) > 0: + # Forward with set_forward_context(decode_data.attn_metadata, self.vllm_config): assert self.model is not None - selected_token_ids = self.model(decode_data.token_ids, - decode_data.position_ids, + selected_token_ids = self.model(decode_data.input_tokens, + decode_data.input_positions, decode_data.attn_metadata, self.kv_caches) - # NOTE: TPU<>CPU sync happens here. - # We need to call .cpu() first to avoid recompilation. - token_ids = selected_token_ids.cpu()[:decode_data.num_decodes] - sampled_token_ids_list.extend(token_ids.tolist()) + # Transfer sampled tokens from TPU to CPU + selected_token_ids_list = selected_token_ids.cpu().tolist() - # UPDATE REQUEST STATE. - for i, req_id in enumerate( - self.input_batch.req_ids[:decode_data.num_decodes]): - assert req_id is not None + # Update cached state + for i, req_id in enumerate(decode_data.req_ids): + req_index = self.input_batch.req_id_to_index[req_id] req_state = self.requests[req_id] - assert scheduler_output.num_scheduled_tokens[req_id] == 1 seq_len = (req_state.num_computed_tokens + scheduler_output.num_scheduled_tokens[req_id]) - assert seq_len == req_state.num_tokens - token_id = sampled_token_ids_list[i] - self.input_batch.token_ids_cpu[i, seq_len] = token_id - self.input_batch.num_tokens[i] += 1 + token_id = selected_token_ids_list[i] + + self.input_batch.token_ids_cpu[req_index, seq_len] = token_id + self.input_batch.num_tokens[req_index] += 1 req_state.output_token_ids.append(token_id) - ######################### PREFILLS ######################### - # Prefills run separately with shape [1, padded_prefill_len], - # due to lack of variable length attention kernel so far. - for (req_id, prompt_len, token_ids, position_ids, - attn_metadata) in prefill_data.zipped(): - assert req_id is not None + sampled_token_ids_list[req_index] = token_id - # FORWARD. + # Run each prompt + for (req_id, prompt_len, input_tokens, input_positions, + attn_metadata) in prompt_data.zipped(): + assert req_id is not None + req_state = self.requests[req_id] + req_index = self.input_batch.req_id_to_index[req_id] + + # Forward with set_forward_context(attn_metadata, self.vllm_config): assert self.model is not None - selected_token_ids = self.model(token_ids, position_ids, + selected_token_ids = self.model(input_tokens, input_positions, attn_metadata, self.kv_caches) - # NOTE: TPU<>CPU sync happens here. - # We need to call .cpu() first to avoid recompilation. - token_id = selected_token_ids.cpu()[prompt_len - 1].item() - sampled_token_ids_list.append(token_id) - req_state = self.requests[req_id] - - assert req_state.num_computed_tokens == 0 seq_len = (req_state.num_computed_tokens + scheduler_output.num_scheduled_tokens[req_id]) - assert seq_len == req_state.num_tokens - assert prompt_len == seq_len + if seq_len >= len(req_state.prompt_token_ids): + # Transfer sampled tokens from TPU to CPU + token_id = selected_token_ids.cpu()[prompt_len - 1].item() + sampled_token_ids_list[req_index] = token_id - # UPDATE REQUEST STATE. 
- req_idx = self.input_batch.req_id_to_index[req_id] - self.input_batch.token_ids_cpu[req_idx, seq_len] = token_id - self.input_batch.num_tokens[req_idx] += 1 - req_state.output_token_ids.append(token_id) + # Update cached state + self.input_batch.token_ids_cpu[req_index, seq_len] = token_id + self.input_batch.num_tokens[req_index] += 1 + req_state.output_token_ids.append(token_id) - # num_reqs entries should be non-None + # Get req_ids assert all( req_id is not None for req_id in self.input_batch.req_ids[:num_reqs]), "req_ids contains None"
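One consequence of chunked prompts in execute_model above: only the chunk that consumes the last prompt token produces a sampled output; earlier chunks just fill the KV cache. A self-contained sketch of that condition, using a hypothetical helper name:

    def should_sample(num_computed_tokens: int,
                      num_scheduled_tokens: int,
                      num_prompt_tokens: int) -> bool:
        """True only for the chunk that reaches the end of the prompt."""
        seq_len = num_computed_tokens + num_scheduled_tokens
        return seq_len >= num_prompt_tokens

    # Assumed example: a 1000-token prompt split into 512 + 488 tokens.
    assert not should_sample(0, 512, 1000)   # first chunk: prefill only
    assert should_sample(512, 488, 1000)     # last chunk: sample one token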