Chunked prompt works!

Signed-off-by: Alexander Matveev <amatveev@redhat.com>
This commit is contained in:
Alexander Matveev 2025-01-28 20:55:45 +00:00
parent 0bddb6b9a5
commit 61bb55f3d5
4 changed files with 406 additions and 466 deletions

View File

@ -213,11 +213,12 @@ class Scheduler:
num_new_tokens = self.block_size num_new_tokens = self.block_size
computed_blocks.pop() computed_blocks.pop()
# TODO: Remove
# If chunked prefill is not enabled, then breakout of the loop # If chunked prefill is not enabled, then breakout of the loop
# when above budget. # when above budget.
if (not self.scheduler_config.chunked_prefill_enabled # if (not self.scheduler_config.chunked_prefill_enabled
and num_new_tokens > token_budget): # and num_new_tokens > token_budget):
break # break
num_new_tokens = min(num_new_tokens, token_budget) num_new_tokens = min(num_new_tokens, token_budget)
assert num_new_tokens > 0 assert num_new_tokens > 0

View File

@ -9,23 +9,18 @@ import torch.distributed
from vllm.config import CompilationLevel, VllmConfig from vllm.config import CompilationLevel, VllmConfig
from vllm.distributed.parallel_state import graph_capture from vllm.distributed.parallel_state import graph_capture
from vllm.forward_context import set_forward_context from vllm.forward_context import set_forward_context
from vllm.inputs import INPUT_REGISTRY
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding
from vllm.model_executor.model_loader import get_model from vllm.model_executor.model_loader import get_model
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs
from vllm.multimodal.utils import group_mm_inputs_by_modality from vllm.multimodal.utils import group_mm_inputs_by_modality
from vllm.sampling_params import SamplingType
from vllm.utils import DeviceMemoryProfiler, cdiv from vllm.utils import DeviceMemoryProfiler, cdiv
from vllm.v1.attention.backends.flash_attn import (FlashAttentionBackend, from vllm.v1.attention.backends.flash_attn import (FlashAttentionBackend,
FlashAttentionMetadata) FlashAttentionMetadata)
from vllm.v1.core.encoder_cache_manager import compute_encoder_budget
from vllm.v1.engine.mm_input_mapper import MMInputMapperClient
from vllm.v1.kv_cache_interface import FullAttentionSpec, KVCacheConfig from vllm.v1.kv_cache_interface import FullAttentionSpec, KVCacheConfig
from vllm.v1.outputs import ModelRunnerOutput from vllm.v1.outputs import ModelRunnerOutput
from vllm.v1.sample.metadata import SamplingMetadata from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.utils import bind_kv_cache from vllm.v1.utils import bind_kv_cache
from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
from vllm.v1.worker.model_runner_base import ExecutionMode, ModelRunnerBase from vllm.v1.worker.model_runner_base import ExecutionMode, ModelRunnerBase
if TYPE_CHECKING: if TYPE_CHECKING:
@ -43,41 +38,9 @@ class GPUModelRunner(ModelRunnerBase):
): ):
super().__init__(vllm_config, device) super().__init__(vllm_config, device)
# Persistent batch.
self.input_batch = InputBatch(
max_num_reqs=self.max_num_reqs,
max_model_len=self.max_model_len,
max_num_blocks_per_req=self.max_num_blocks_per_req,
device=self.device,
pin_memory=self.pin_memory,
vocab_size=self.model_config.get_vocab_size(),
)
# Request states.
self.requests: Dict[str, CachedRequestState] = {}
# KV caches for forward pass # KV caches for forward pass
self.kv_caches: List[torch.Tensor] = [] self.kv_caches: List[torch.Tensor] = []
# Multi-modal data support
self.input_registry = INPUT_REGISTRY
self.mm_registry = MULTIMODAL_REGISTRY
# NOTE: Initialized input mapper is only used for processing dummy
# multimodal data into multimodal kwargs for GPU memory profiling.
self.mm_input_mapper_profiling = MMInputMapperClient(self.model_config)
self.mm_input_mapper_profiling.use_cache = False
encoder_compute_budget, encoder_cache_size = compute_encoder_budget(
model_config=self.model_config,
scheduler_config=self.scheduler_config,
)
self.max_num_encoder_input_tokens = encoder_compute_budget
self.encoder_cache_size = encoder_cache_size
# req_id -> (input_id -> encoder_output)
self.encoder_cache: Dict[str, Dict[int, torch.Tensor]] = {}
self.use_cuda_graph = (self.vllm_config.compilation_config.level self.use_cuda_graph = (self.vllm_config.compilation_config.level
== CompilationLevel.PIECEWISE == CompilationLevel.PIECEWISE
and not self.model_config.enforce_eager) and not self.model_config.enforce_eager)
@ -160,132 +123,6 @@ class GPUModelRunner(ModelRunnerBase):
pin_memory=self.pin_memory) pin_memory=self.pin_memory)
self.seq_lens_np = self.seq_lens_cpu.numpy() self.seq_lens_np = self.seq_lens_cpu.numpy()
def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
# Remove stopped requests from the cached states.
# Keep the states of the pre-empted requests.
for req_id in scheduler_output.finished_req_ids:
self.requests.pop(req_id, None)
self.encoder_cache.pop(req_id, None)
# Free the cached encoder outputs.
for req_id, input_id in scheduler_output.free_encoder_input_ids:
encoder_outputs = self.encoder_cache.get(req_id)
if encoder_outputs is not None:
encoder_outputs.pop(input_id, None)
if not encoder_outputs:
self.encoder_cache.pop(req_id, None)
# Remove the requests from the persistent batch.
stopped_req_ids = set().union(
scheduler_output.preempted_req_ids,
scheduler_output.finished_req_ids,
)
removed_req_indices: List[int] = []
for req_id in stopped_req_ids:
req_index = self.input_batch.remove_request(req_id)
if req_index is not None:
removed_req_indices.append(req_index)
# Update the states of the running requests.
for req_data in scheduler_output.scheduled_running_reqs:
req_id = req_data.req_id
req_state = self.requests[req_id]
req_index = self.input_batch.req_id_to_index[req_id]
# Update the num_computed_tokens.
req_state.num_computed_tokens = req_data.num_computed_tokens
self.input_batch.num_computed_tokens_cpu[req_index] = (
req_data.num_computed_tokens)
# Update the block table.
num_new_blocks = len(req_data.new_block_ids)
if num_new_blocks == 0:
continue
start_index = len(req_state.block_ids)
req_state.block_ids.extend(req_data.new_block_ids)
self.input_batch.block_table.append_row(req_index, start_index,
req_data.new_block_ids)
req_ids_to_add: List[str] = []
# Add new requests to the cached states.
for new_req_data in scheduler_output.scheduled_new_reqs:
req_id = new_req_data.req_id
sampling_params = new_req_data.sampling_params
if sampling_params.sampling_type == SamplingType.RANDOM_SEED:
generator = torch.Generator(device=self.device)
generator.manual_seed(sampling_params.seed)
else:
generator = None
self.requests[req_id] = CachedRequestState(
req_id=req_id,
prompt_token_ids=new_req_data.prompt_token_ids,
prompt=new_req_data.prompt,
mm_inputs=new_req_data.mm_inputs,
mm_positions=new_req_data.mm_positions,
sampling_params=sampling_params,
generator=generator,
block_ids=new_req_data.block_ids,
num_computed_tokens=new_req_data.num_computed_tokens,
output_token_ids=[],
)
# Only relevant for models using M-RoPE (e.g, Qwen2-VL)
if self.model_config.uses_mrope:
image_grid_thw = []
video_grid_thw = []
for mm_input in self.requests[req_id].mm_inputs:
if mm_input.get("image_grid_thw") is not None:
image_grid_thw.extend(
mm_input["image_grid_thw"].tolist())
if mm_input.get("video_grid_thw") is not None:
video_grid_thw.extend(
mm_input["video_grid_thw"].tolist())
hf_config = self.model_config.hf_config
self.requests[req_id].mrope_positions, \
self.requests[req_id].mrope_position_delta = \
MRotaryEmbedding.get_input_positions_tensor(
self.requests[req_id].prompt_token_ids,
image_grid_thw=image_grid_thw,
video_grid_thw=video_grid_thw,
image_token_id=hf_config.image_token_id,
video_token_id=hf_config.video_token_id,
vision_start_token_id=hf_config.vision_start_token_id,
vision_end_token_id=hf_config.vision_end_token_id,
spatial_merge_size=hf_config.vision_config.
spatial_merge_size,
)
req_ids_to_add.append(req_id)
# Update the cached states of the resumed requests.
for res_req_data in scheduler_output.scheduled_resumed_reqs:
req_id = res_req_data.req_id
req_state = self.requests[req_id]
req_state.block_ids = res_req_data.block_ids
req_state.num_computed_tokens = res_req_data.num_computed_tokens
req_ids_to_add.append(req_id)
# Add the new or resumed requests to the persistent batch.
# The smaller empty indices are filled first.
removed_req_indices = sorted(removed_req_indices, reverse=True)
for req_id in req_ids_to_add:
req_state = self.requests[req_id]
if removed_req_indices:
# Fill the empty index.
req_index = removed_req_indices.pop()
else:
# Append to the end.
req_index = None
self.input_batch.add_request(req_state, req_index)
# Condense the batched states if there are empty indices.
if removed_req_indices:
self.input_batch.condense(removed_req_indices)
def _prepare_inputs(self, scheduler_output: "SchedulerOutput"): def _prepare_inputs(self, scheduler_output: "SchedulerOutput"):
total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
assert total_num_scheduled_tokens > 0 assert total_num_scheduled_tokens > 0
@ -665,7 +502,7 @@ class GPUModelRunner(ModelRunnerBase):
) -> ModelRunnerOutput: ) -> ModelRunnerOutput:
assert self.model is not None assert self.model is not None
self._update_states(scheduler_output) self.update_states(scheduler_output)
if self.is_multimodal_model: if self.is_multimodal_model:
# Run the multimodal encoder if any. # Run the multimodal encoder if any.

View File

@ -1,5 +1,5 @@
import enum import enum
from typing import TYPE_CHECKING, Optional from typing import TYPE_CHECKING, Dict, List, Optional
import torch import torch
import torch.distributed import torch.distributed
@ -8,11 +8,18 @@ import torch.nn as nn
from vllm.attention.backends.abstract import AttentionType from vllm.attention.backends.abstract import AttentionType
from vllm.attention.layer import Attention from vllm.attention.layer import Attention
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.inputs import INPUT_REGISTRY
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.sampling_params import SamplingType
from vllm.utils import LayerBlockType, cdiv, is_pin_memory_available from vllm.utils import LayerBlockType, cdiv, is_pin_memory_available
from vllm.v1.core.encoder_cache_manager import compute_encoder_budget
from vllm.v1.engine.mm_input_mapper import MMInputMapperClient
from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig, from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
KVCacheSpec) KVCacheSpec)
from vllm.v1.outputs import ModelRunnerOutput from vllm.v1.outputs import ModelRunnerOutput
from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
if TYPE_CHECKING: if TYPE_CHECKING:
from vllm.v1.core.scheduler import SchedulerOutput from vllm.v1.core.scheduler import SchedulerOutput
@ -75,6 +82,164 @@ class ModelRunnerBase:
self.model: Optional[nn.Module] = None self.model: Optional[nn.Module] = None
# Persistent batch.
self.input_batch = InputBatch(
max_num_reqs=self.max_num_reqs,
max_model_len=self.max_model_len,
max_num_blocks_per_req=self.max_num_blocks_per_req,
device=self.device,
pin_memory=self.pin_memory,
vocab_size=self.model_config.get_vocab_size(),
)
# Request states.
self.requests: Dict[str, CachedRequestState] = {}
# Multi-modal data support
self.input_registry = INPUT_REGISTRY
self.mm_registry = MULTIMODAL_REGISTRY
# NOTE: Initialized input mapper is only used for processing dummy
# multimodal data into multimodal kwargs for GPU memory profiling.
self.mm_input_mapper_profiling = MMInputMapperClient(self.model_config)
self.mm_input_mapper_profiling.use_cache = False
encoder_compute_budget, encoder_cache_size = compute_encoder_budget(
model_config=self.model_config,
scheduler_config=self.scheduler_config,
)
self.max_num_encoder_input_tokens = encoder_compute_budget
self.encoder_cache_size = encoder_cache_size
# req_id -> (input_id -> encoder_output)
self.encoder_cache: Dict[str, Dict[int, torch.Tensor]] = {}
def update_states(self, scheduler_output: "SchedulerOutput") -> None:
# Remove stopped requests from the cached states.
# Keep the states of the pre-empted requests.
for req_id in scheduler_output.finished_req_ids:
self.requests.pop(req_id, None)
self.encoder_cache.pop(req_id, None)
# Free the cached encoder outputs.
for req_id, input_id in scheduler_output.free_encoder_input_ids:
encoder_outputs = self.encoder_cache.get(req_id)
if encoder_outputs is not None:
encoder_outputs.pop(input_id, None)
if not encoder_outputs:
self.encoder_cache.pop(req_id, None)
# Remove the requests from the persistent batch.
stopped_req_ids = set().union(
scheduler_output.preempted_req_ids,
scheduler_output.finished_req_ids,
)
removed_req_indices: List[int] = []
for req_id in stopped_req_ids:
req_index = self.input_batch.remove_request(req_id)
if req_index is not None:
removed_req_indices.append(req_index)
# Update the states of the running requests.
for req_data in scheduler_output.scheduled_running_reqs:
req_id = req_data.req_id
req_state = self.requests[req_id]
req_index = self.input_batch.req_id_to_index[req_id]
# Update the num_computed_tokens.
req_state.num_computed_tokens = req_data.num_computed_tokens
self.input_batch.num_computed_tokens_cpu[req_index] = (
req_data.num_computed_tokens)
# Update the block table.
num_new_blocks = len(req_data.new_block_ids)
if num_new_blocks == 0:
continue
start_index = len(req_state.block_ids)
req_state.block_ids.extend(req_data.new_block_ids)
self.input_batch.block_table.append_row(req_index, start_index,
req_data.new_block_ids)
req_ids_to_add: List[str] = []
# Add new requests to the cached states.
for new_req_data in scheduler_output.scheduled_new_reqs:
req_id = new_req_data.req_id
sampling_params = new_req_data.sampling_params
if sampling_params.sampling_type == SamplingType.RANDOM_SEED:
generator = torch.Generator(device=self.device)
generator.manual_seed(sampling_params.seed)
else:
generator = None
self.requests[req_id] = CachedRequestState(
req_id=req_id,
prompt_token_ids=new_req_data.prompt_token_ids,
prompt=new_req_data.prompt,
mm_inputs=new_req_data.mm_inputs,
mm_positions=new_req_data.mm_positions,
sampling_params=sampling_params,
generator=generator,
block_ids=new_req_data.block_ids,
num_computed_tokens=new_req_data.num_computed_tokens,
output_token_ids=[],
)
# Only relevant for models using M-RoPE (e.g, Qwen2-VL)
if self.model_config.uses_mrope:
image_grid_thw = []
video_grid_thw = []
for mm_input in self.requests[req_id].mm_inputs:
if mm_input.get("image_grid_thw") is not None:
image_grid_thw.extend(
mm_input["image_grid_thw"].tolist())
if mm_input.get("video_grid_thw") is not None:
video_grid_thw.extend(
mm_input["video_grid_thw"].tolist())
hf_config = self.model_config.hf_config
self.requests[req_id].mrope_positions, \
self.requests[req_id].mrope_position_delta = \
MRotaryEmbedding.get_input_positions_tensor(
self.requests[req_id].prompt_token_ids,
image_grid_thw=image_grid_thw,
video_grid_thw=video_grid_thw,
image_token_id=hf_config.image_token_id,
video_token_id=hf_config.video_token_id,
vision_start_token_id=hf_config.vision_start_token_id,
vision_end_token_id=hf_config.vision_end_token_id,
spatial_merge_size=hf_config.vision_config.
spatial_merge_size,
)
req_ids_to_add.append(req_id)
# Update the cached states of the resumed requests.
for res_req_data in scheduler_output.scheduled_resumed_reqs:
req_id = res_req_data.req_id
req_state = self.requests[req_id]
req_state.block_ids = res_req_data.block_ids
req_state.num_computed_tokens = res_req_data.num_computed_tokens
req_ids_to_add.append(req_id)
# Add the new or resumed requests to the persistent batch.
# The smaller empty indices are filled first.
removed_req_indices = sorted(removed_req_indices, reverse=True)
for req_id in req_ids_to_add:
req_state = self.requests[req_id]
if removed_req_indices:
# Fill the empty index.
req_index = removed_req_indices.pop()
else:
# Append to the end.
req_index = None
self.input_batch.add_request(req_state, req_index)
# Condense the batched states if there are empty indices.
if removed_req_indices:
self.input_batch.condense(removed_req_indices)
def get_model(self) -> nn.Module: def get_model(self) -> nn.Module:
assert self.model is not None assert self.model is not None
return self.model return self.model

View File

@ -15,7 +15,6 @@ from vllm.config import VllmConfig
from vllm.forward_context import set_forward_context from vllm.forward_context import set_forward_context
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.model_loader import get_model from vllm.model_executor.model_loader import get_model
from vllm.sampling_params import SamplingType
from vllm.v1.attention.backends.pallas import (PallasAttentionBackend, from vllm.v1.attention.backends.pallas import (PallasAttentionBackend,
PallasMetadata) PallasMetadata)
from vllm.v1.kv_cache_interface import FullAttentionSpec, KVCacheConfig from vllm.v1.kv_cache_interface import FullAttentionSpec, KVCacheConfig
@ -40,25 +39,24 @@ _MAX_NUM_SAMPLES = 128
@dataclass @dataclass
class PrefillInputData: class PromptInputData:
request_ids: List req_ids: List
prompt_lens: List prompt_lens: List
token_ids: List input_tokens: List
position_ids: List input_positions: List
attn_metadata: List attn_metadata: List
def zipped(self): def zipped(self):
return zip(self.request_ids, self.prompt_lens, self.token_ids, return zip(self.req_ids, self.prompt_lens, self.input_tokens,
self.position_ids, self.attn_metadata) self.input_positions, self.attn_metadata)
@dataclass @dataclass
class DecodeInputData: class DecodeInputData:
req_ids: List
num_decodes: int input_tokens: Optional[torch.Tensor] = None
token_ids: Optional[torch.Tensor] = None input_positions: Optional[torch.Tensor] = None
position_ids: Optional[torch.Tensor] = None
attn_metadata: Optional[PallasMetadata] = None attn_metadata: Optional[PallasMetadata] = None
@ -88,158 +86,105 @@ class TPUModelRunner(ModelRunnerBase):
self.kv_caches: List[Tuple[torch.Tensor, torch.Tensor]] = [] self.kv_caches: List[Tuple[torch.Tensor, torch.Tensor]] = []
# Used to initialize positions for the individual prefills # Used to initialize positions for the individual prefills
self.prefill_positions = torch.tensor(range(self.max_model_len), self.prefill_input_positions = torch.tensor(range(self.max_model_len),
device="cpu", device="cpu",
dtype=torch.int32).reshape( dtype=torch.int32).reshape(
1, -1) 1, -1)
# Used to indicate how many prefills there are for each scheduler def _prepare_prompt_inputs(
# iteration
self.num_new_reqs: int = 0
def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
# Remove stopped requests from the cached states.
# Keep the states of the pre-empted requests.
for req_id in scheduler_output.finished_req_ids:
self.requests.pop(req_id, None)
# Remove the requests from the persistent batch.
stopped_req_ids = set().union(
scheduler_output.preempted_req_ids,
scheduler_output.finished_req_ids,
)
removed_req_indices: List[int] = []
for req_id in stopped_req_ids:
req_index = self.input_batch.remove_request(req_id)
if req_index is not None:
removed_req_indices.append(req_index)
# Update the states of the running requests.
for req_data in scheduler_output.scheduled_running_reqs:
req_id = req_data.req_id
req_state = self.requests[req_id]
req_index = self.input_batch.req_id_to_index[req_id]
# Update the num_computed_tokens.
req_state.num_computed_tokens = req_data.num_computed_tokens
self.input_batch.num_computed_tokens_cpu[req_index] = (
req_data.num_computed_tokens)
# Update the block table.
num_new_blocks = len(req_data.new_block_ids)
if num_new_blocks == 0:
continue
start_index = len(req_state.block_ids)
req_state.block_ids.extend(req_data.new_block_ids)
self.input_batch.block_table.append_row(req_index, start_index,
req_data.new_block_ids)
req_ids_to_add: List[str] = []
# Add new requests to the cached states.
for new_req_data in scheduler_output.scheduled_new_reqs:
req_id = new_req_data.req_id
sampling_params = new_req_data.sampling_params
if sampling_params.sampling_type == SamplingType.RANDOM_SEED:
generator = torch.Generator(device=self.device)
generator.manual_seed(sampling_params.seed)
else:
generator = None
self.requests[req_id] = CachedRequestState(
req_id=req_id,
prompt_token_ids=new_req_data.prompt_token_ids,
prompt=new_req_data.prompt,
mm_inputs=new_req_data.mm_inputs,
mm_positions=new_req_data.mm_positions,
sampling_params=sampling_params,
generator=generator,
block_ids=new_req_data.block_ids,
num_computed_tokens=new_req_data.num_computed_tokens,
output_token_ids=[],
)
req_ids_to_add.append(req_id)
# Update the cached states of the resumed requests.
for res_req_data in scheduler_output.scheduled_resumed_reqs:
req_id = res_req_data.req_id
req_state = self.requests[req_id]
req_state.block_ids = res_req_data.block_ids
req_state.num_computed_tokens = res_req_data.num_computed_tokens
req_ids_to_add.append(req_id)
# For TPU, we keep all of the decode requests before the
# prefill requests in the batch sequence.
# 1. First condense, so all decodes move to start
# 2. Then add new prefills to the end of the batch
removed_req_indices = sorted(removed_req_indices, reverse=True)
if removed_req_indices:
self.input_batch.condense(removed_req_indices)
for req_id in req_ids_to_add:
req_state = self.requests[req_id]
self.input_batch.add_request(req_state, None) # Append last
self.num_new_reqs = len(req_ids_to_add)
def _prepare_prefill_inputs(
self, self,
num_scheduled_tokens: List[int], scheduler_output: "SchedulerOutput",
) -> PrefillInputData: ) -> PromptInputData:
# Each prefill run separately with shape [1, padded_prompt_len]. total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
# So we create lists that will be used in execute_model(). assert total_num_scheduled_tokens > 0
prefill_request_ids = []
prefill_prompt_lens = []
prefill_token_ids = []
prefill_position_ids = []
prefill_attn_metadata = []
# DECODES are the first num_decodes REQUESTS.
# PREFILLS are the next num_reqs - num_decodes REQUESTS.
num_reqs = self.input_batch.num_reqs num_reqs = self.input_batch.num_reqs
num_decodes = num_reqs - self.num_new_reqs assert num_reqs > 0
for idx in range(num_decodes, num_reqs):
req_id = self.input_batch.req_ids[idx]
prefill_request_ids.append(req_id)
prompt_len = num_scheduled_tokens[idx] # OPTIMIZATION: Start copying the block table first.
prefill_prompt_lens.append(prompt_len) # This way, we can overlap the copy with the following CPU operations.
self.input_batch.block_table.commit(num_reqs)
# STATIC SHAPE: prefills are padded to the next power of 2. req_ids = []
prompt_lens = []
input_tokens_list = []
input_positions_list = []
attn_metadata_list = []
for req_id in self.input_batch.req_ids[:num_reqs]:
assert req_id is not None
req_index = self.input_batch.req_id_to_index[req_id]
req_state = self.requests[req_id]
num_scheduled_tokens = scheduler_output.num_scheduled_tokens[
req_id]
num_computed_tokens = req_state.num_computed_tokens
num_prompt_tokens = len(req_state.prompt_token_ids)
# Detect whether this is a prompt (can be full or chunked)
if num_computed_tokens >= num_prompt_tokens:
# This is a decode => Skip
continue
# This is a prompt
req_ids.append(req_id)
# Prompt len
prompt_len = num_scheduled_tokens
prompt_lens.append(prompt_len)
padded_prompt_len = _get_padded_prefill_len(prompt_len) padded_prompt_len = _get_padded_prefill_len(prompt_len)
assert padded_prompt_len <= self.max_model_len assert padded_prompt_len <= self.max_model_len
# TOKEN_IDS. # Seq len
token_ids = torch.from_numpy(self.input_batch.token_ids_cpu[ seq_len = num_computed_tokens + prompt_len
idx, :padded_prompt_len].reshape(1, -1)) padded_seq_len = num_computed_tokens + padded_prompt_len
token_ids[:, prompt_len:] = 0
prefill_token_ids.append(token_ids.to(self.device))
# POSITIONS. # Input tokens
positions = self.prefill_positions[:, :padded_prompt_len].clone() input_tokens = torch.from_numpy(self.input_batch.token_ids_cpu[
positions[:, prompt_len:] = 0 req_index, num_computed_tokens:padded_seq_len].reshape(1, -1))
prefill_position_ids.append(positions.to(self.device)) input_tokens[:, prompt_len:] = 0
input_tokens_list.append(input_tokens.to(self.device))
# SLOT_MAPPING. # Input positions
# The "slot" is the "physical index" of a token in the KV cache. input_positions = self.prefill_input_positions[:,
# Look up the block_idx in the block table (logical<>physical map) num_computed_tokens:
# to compute this. padded_seq_len].clone(
)
input_positions[:, prompt_len:] = 0
input_positions_list.append(input_positions.to(self.device))
# Slot mapping
block_table_cpu_tensor = \ block_table_cpu_tensor = \
self.input_batch.block_table.get_cpu_tensor() self.input_batch.block_table.get_cpu_tensor()
block_numbers = block_table_cpu_tensor[req_index,
block_numbers = block_table_cpu_tensor[idx, positions // input_positions //
self.block_size].reshape( self.block_size].reshape(
1, -1) 1, -1)
block_offsets = positions % self.block_size block_offsets = input_positions % self.block_size
slot_mapping = block_numbers * self.block_size + block_offsets slot_mapping = block_numbers * self.block_size + block_offsets
# Set an out of range value for the padding tokens so that they
# are ignored when inserting into the KV cache.
slot_mapping[:, prompt_len:] = _PAD_SLOT_ID slot_mapping[:, prompt_len:] = _PAD_SLOT_ID
slot_mapping = slot_mapping.long() slot_mapping = slot_mapping.long()
prefill_attn_metadata.append( # Block table
block_table = None
if num_computed_tokens > 0:
block_table = self.input_batch.block_table.get_device_tensor()
block_table = block_table[req_index].unsqueeze(0)
# Context len
context_len = 0
if num_computed_tokens > 0:
context_len = seq_len
context_lens = torch.tensor([context_len],
dtype=torch.int32,
device="cpu")
# Effective query len
effective_query_lens = torch.tensor([prompt_len],
dtype=torch.int32,
device="cpu")
# Attn metadata
attn_metadata_list.append(
PallasMetadata( PallasMetadata(
num_prefills=1, num_prefills=1,
num_prefill_tokens=0, # NOTE: This is not used. num_prefill_tokens=0, # NOTE: This is not used.
@ -247,208 +192,200 @@ class TPUModelRunner(ModelRunnerBase):
slot_mapping=slot_mapping.to(self.device), slot_mapping=slot_mapping.to(self.device),
multi_modal_placeholder_index_maps=None, multi_modal_placeholder_index_maps=None,
enable_kv_scales_calculation=True, enable_kv_scales_calculation=True,
block_tables=None, block_tables=block_table,
context_lens=None, context_lens=context_lens.to(self.device),
effective_query_lens=None, effective_query_lens=effective_query_lens.to(self.device),
)) ))
return PrefillInputData( return PromptInputData(
request_ids=prefill_request_ids, req_ids=req_ids,
prompt_lens=prefill_prompt_lens, prompt_lens=prompt_lens,
token_ids=prefill_token_ids, input_tokens=input_tokens_list,
position_ids=prefill_position_ids, input_positions=input_positions_list,
attn_metadata=prefill_attn_metadata, attn_metadata=attn_metadata_list,
) )
def _prepare_decode_inputs(self) -> DecodeInputData: def _prepare_decode_inputs(
# Decodes run as one single padded batch with shape [batch, 1] self,
# scheduler_output: "SchedulerOutput",
# We need to set _PAD_SLOT_ID for the padding tokens in the ) -> DecodeInputData:
# slot_mapping, such that the attention KV cache insertion
# logic knows to ignore those indices. Otherwise, the
# padding data can be dummy since we have a causal mask.
# DECODES are the first num_decodes REQUESTS.
# PREFILLS are the next num_reqs - num_decodes REQUESTS.
num_reqs = self.input_batch.num_reqs
num_decodes = num_reqs - self.num_new_reqs
if num_decodes == 0:
return DecodeInputData(num_decodes=0)
# PAD FOR STATIC SHAPES.
padded_batch_size = _get_padded_batch_size(num_decodes)
# POSITIONS. [batch, 1]
# We slice at the end, since we use the positions for gathering.
positions = torch.from_numpy(
self.input_batch.num_computed_tokens_cpu.reshape(-1, 1))
index = positions.to(torch.int64)
index[num_decodes:] = 0
positions = positions[:padded_batch_size]
positions[num_decodes:] = 0
# TOKEN_IDS. [batch, 1]
token_ids = torch.gather(
input=torch.from_numpy(self.input_batch.token_ids_cpu),
dim=1,
index=index,
)[:padded_batch_size].to(torch.int32)
token_ids[num_decodes:] = 0
# SLOT_MAPPING [batch, 1]
# The "slot" is the "physical index" of a token in the KV cache.
# Look up the block_idx in the block table (logical<>physical map)
# to compute this.
block_table_cpu_tensor = self.input_batch.block_table.get_cpu_tensor()
block_number = torch.gather(input=block_table_cpu_tensor,
dim=1,
index=(index // self.block_size))
block_offsets = index % self.block_size
slot_mapping = block_number * self.block_size + block_offsets
# Set an out of range value for the padding tokens so that they
# are ignored when inserting into the KV cache.
slot_mapping[num_decodes:] = _PAD_SLOT_ID
slot_mapping = slot_mapping[:padded_batch_size]
slot_mapping = slot_mapping.long()
# BLOCK_TABLE [batch, max_num_blocks_per_req]
block_table = block_table_cpu_tensor[:padded_batch_size]
# CONTEXT_LENS [batch_size]
context_lens = (positions.reshape(-1) + 1)
context_lens[num_decodes:] = 0
# CPU<>TPU sync happens here.
return DecodeInputData(num_decodes=num_decodes,
token_ids=token_ids.to(self.device),
position_ids=positions.to(self.device),
attn_metadata=PallasMetadata(
num_prefills=0,
num_prefill_tokens=0,
num_decode_tokens=padded_batch_size,
slot_mapping=slot_mapping.to(self.device),
multi_modal_placeholder_index_maps=None,
enable_kv_scales_calculation=True,
block_tables=block_table.to(self.device),
context_lens=context_lens.to(self.device),
effective_query_lens=None,
))
def _prepare_inputs(self, scheduler_output: "SchedulerOutput"):
total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
assert total_num_scheduled_tokens > 0 assert total_num_scheduled_tokens > 0
num_reqs = self.input_batch.num_reqs num_reqs = self.input_batch.num_reqs
assert num_reqs > 0 assert num_reqs > 0
num_decodes = num_reqs - self.num_new_reqs block_table_cpu_tensor = self.input_batch.block_table.get_cpu_tensor()
# TODO: Resurrect req_ids = []
# OPTIMIZATION: Start copying the block table first. req_indices = []
# This way, we can overlap the copy with the following CPU operations. input_tokens = []
# TODO: Verify this works with TPUs input_positions = []
# self.input_batch.block_table.commit(num_reqs) slot_mapping = []
context_lens = []
# Get the number of scheduled tokens for each request. for req_id in self.input_batch.req_ids[:num_reqs]:
# TODO: The Python loop can be slow. Optimize.
num_scheduled_tokens = []
max_num_scheduled_tokens = 0
for idx, req_id in enumerate(self.input_batch.req_ids[:num_reqs]):
assert req_id is not None assert req_id is not None
num_tokens = scheduler_output.num_scheduled_tokens[req_id] req_index = self.input_batch.req_id_to_index[req_id]
num_scheduled_tokens.append(num_tokens) req_state = self.requests[req_id]
max_num_scheduled_tokens = max(max_num_scheduled_tokens,
num_tokens)
# NOTE: Assert that all the decodes are "decodes". num_scheduled_tokens = scheduler_output.num_scheduled_tokens[
if idx < num_decodes: req_id]
assert num_tokens == 1 num_computed_tokens = req_state.num_computed_tokens
num_prompt_tokens = len(req_state.prompt_token_ids)
assert max_num_scheduled_tokens > 0 # Detect whether this is a decode
if num_computed_tokens < num_prompt_tokens:
# This is a prompt => Skip
continue
return ( # This is a decode
self._prepare_prefill_inputs(num_scheduled_tokens), req_ids.append(req_id)
self._prepare_decode_inputs(), req_indices.append(req_index)
# Seq len
seq_len = num_computed_tokens + num_scheduled_tokens
# Sanity check decode
assert num_scheduled_tokens == 1
assert seq_len == req_state.num_tokens
# Input token
input_tokens.append([
self.input_batch.token_ids_cpu[req_index, num_computed_tokens]
])
# Position
input_positions.append([num_computed_tokens])
# Slot mapping
block_number = block_table_cpu_tensor[req_index,
num_computed_tokens //
self.block_size]
block_offset = num_computed_tokens % self.block_size
slot_id = block_number * self.block_size + block_offset
slot_mapping.append([slot_id])
# Context len
context_lens.append(seq_len)
# Compute padding
batch_size = len(input_tokens)
padded_batch_size = _get_padded_batch_size(batch_size)
num_padding = padded_batch_size - batch_size
# Add padding
input_tokens.extend([[0]] * num_padding)
input_positions.extend([[0]] * num_padding)
slot_mapping.extend([[_PAD_SLOT_ID]] * num_padding)
context_lens.extend([0] * num_padding)
req_indices.extend([0] * num_padding)
# Create tensors
input_tokens_tensor = torch.tensor(input_tokens,
dtype=torch.int32,
device="cpu")
input_positions_tensor = torch.tensor(input_positions,
dtype=torch.int32,
device="cpu")
slot_mapping_tensor = torch.tensor(slot_mapping,
dtype=torch.int64,
device="cpu")
context_lens_tensor = torch.tensor(context_lens,
dtype=torch.int32,
device="cpu")
block_tables_tensor = block_table_cpu_tensor[req_indices]
# Attn metadata
attn_metadata = PallasMetadata(
num_prefills=0,
num_prefill_tokens=0,
num_decode_tokens=padded_batch_size,
slot_mapping=slot_mapping_tensor.to(self.device),
multi_modal_placeholder_index_maps=None,
enable_kv_scales_calculation=True,
block_tables=block_tables_tensor.to(self.device),
context_lens=context_lens_tensor.to(self.device),
effective_query_lens=None,
) )
return DecodeInputData(
req_ids=req_ids,
input_tokens=input_tokens_tensor.to(self.device),
input_positions=input_positions_tensor.to(self.device),
attn_metadata=attn_metadata)
@torch.no_grad() @torch.no_grad()
def execute_model( def execute_model(
self, self,
scheduler_output: "SchedulerOutput", scheduler_output: "SchedulerOutput",
) -> ModelRunnerOutput: ) -> ModelRunnerOutput:
self._update_states(scheduler_output) # Update cached state
self.update_states(scheduler_output)
# Prepare the decoder inputs. # Prepare inputs
prefill_data, decode_data = self._prepare_inputs(scheduler_output) prompt_data = self._prepare_prompt_inputs(scheduler_output)
decode_data = self._prepare_decode_inputs(scheduler_output)
# Init
num_reqs = self.input_batch.num_reqs num_reqs = self.input_batch.num_reqs
assert num_reqs > 0
sampled_token_ids_list = [0] * num_reqs
######################### DECODES ######################### # Run decodes (a single batch)
# Decodes run as one single batch with [padded_batch, 1] if len(decode_data.req_ids) > 0:
sampled_token_ids_list = [] # Forward
if decode_data.num_decodes > 0:
# FORWARD.
with set_forward_context(decode_data.attn_metadata, with set_forward_context(decode_data.attn_metadata,
self.vllm_config): self.vllm_config):
assert self.model is not None assert self.model is not None
selected_token_ids = self.model(decode_data.token_ids, selected_token_ids = self.model(decode_data.input_tokens,
decode_data.position_ids, decode_data.input_positions,
decode_data.attn_metadata, decode_data.attn_metadata,
self.kv_caches) self.kv_caches)
# NOTE: TPU<>CPU sync happens here. # Transfer sampled tokens from TPU to CPU
# We need to call .cpu() first to avoid recompilation. selected_token_ids_list = selected_token_ids.cpu().tolist()
token_ids = selected_token_ids.cpu()[:decode_data.num_decodes]
sampled_token_ids_list.extend(token_ids.tolist())
# UPDATE REQUEST STATE. # Update cached state
for i, req_id in enumerate( for i, req_id in enumerate(decode_data.req_ids):
self.input_batch.req_ids[:decode_data.num_decodes]): req_index = self.input_batch.req_id_to_index[req_id]
assert req_id is not None
req_state = self.requests[req_id] req_state = self.requests[req_id]
assert scheduler_output.num_scheduled_tokens[req_id] == 1
seq_len = (req_state.num_computed_tokens + seq_len = (req_state.num_computed_tokens +
scheduler_output.num_scheduled_tokens[req_id]) scheduler_output.num_scheduled_tokens[req_id])
assert seq_len == req_state.num_tokens
token_id = sampled_token_ids_list[i] token_id = selected_token_ids_list[i]
self.input_batch.token_ids_cpu[i, seq_len] = token_id
self.input_batch.num_tokens[i] += 1 self.input_batch.token_ids_cpu[req_index, seq_len] = token_id
self.input_batch.num_tokens[req_index] += 1
req_state.output_token_ids.append(token_id) req_state.output_token_ids.append(token_id)
######################### PREFILLS ######################### sampled_token_ids_list[req_index] = token_id
# Prefills run separately with shape [1, padded_prefill_len],
# due to lack of variable length attention kernel so far.
for (req_id, prompt_len, token_ids, position_ids,
attn_metadata) in prefill_data.zipped():
assert req_id is not None
# FORWARD. # Run each prompt
for (req_id, prompt_len, input_tokens, input_positions,
attn_metadata) in prompt_data.zipped():
assert req_id is not None
req_state = self.requests[req_id]
req_index = self.input_batch.req_id_to_index[req_id]
# Forward
with set_forward_context(attn_metadata, self.vllm_config): with set_forward_context(attn_metadata, self.vllm_config):
assert self.model is not None assert self.model is not None
selected_token_ids = self.model(token_ids, position_ids, selected_token_ids = self.model(input_tokens, input_positions,
attn_metadata, self.kv_caches) attn_metadata, self.kv_caches)
# NOTE: TPU<>CPU sync happens here.
# We need to call .cpu() first to avoid recompilation.
token_id = selected_token_ids.cpu()[prompt_len - 1].item()
sampled_token_ids_list.append(token_id)
req_state = self.requests[req_id]
assert req_state.num_computed_tokens == 0
seq_len = (req_state.num_computed_tokens + seq_len = (req_state.num_computed_tokens +
scheduler_output.num_scheduled_tokens[req_id]) scheduler_output.num_scheduled_tokens[req_id])
assert seq_len == req_state.num_tokens if seq_len >= len(req_state.prompt_token_ids):
assert prompt_len == seq_len # Transfer sampled tokens from TPU to CPU
token_id = selected_token_ids.cpu()[prompt_len - 1].item()
sampled_token_ids_list[req_index] = token_id
# UPDATE REQUEST STATE. # Update cached state
req_idx = self.input_batch.req_id_to_index[req_id] self.input_batch.token_ids_cpu[req_index, seq_len] = token_id
self.input_batch.token_ids_cpu[req_idx, seq_len] = token_id self.input_batch.num_tokens[req_index] += 1
self.input_batch.num_tokens[req_idx] += 1 req_state.output_token_ids.append(token_id)
req_state.output_token_ids.append(token_id)
# num_reqs entries should be non-None # Get req_ids
assert all( assert all(
req_id is not None for req_id in req_id is not None for req_id in
self.input_batch.req_ids[:num_reqs]), "req_ids contains None" self.input_batch.req_ids[:num_reqs]), "req_ids contains None"