Chunked prompt works!

Signed-off-by: Alexander Matveev <amatveev@redhat.com>
This commit is contained in:
Alexander Matveev 2025-01-28 20:55:45 +00:00
parent 0bddb6b9a5
commit 61bb55f3d5
4 changed files with 406 additions and 466 deletions

View File

@ -213,11 +213,12 @@ class Scheduler:
num_new_tokens = self.block_size
computed_blocks.pop()
# TODO: Remove
# If chunked prefill is not enabled, then breakout of the loop
# when above budget.
if (not self.scheduler_config.chunked_prefill_enabled
and num_new_tokens > token_budget):
break
# if (not self.scheduler_config.chunked_prefill_enabled
# and num_new_tokens > token_budget):
# break
num_new_tokens = min(num_new_tokens, token_budget)
assert num_new_tokens > 0

View File

@ -9,23 +9,18 @@ import torch.distributed
from vllm.config import CompilationLevel, VllmConfig
from vllm.distributed.parallel_state import graph_capture
from vllm.forward_context import set_forward_context
from vllm.inputs import INPUT_REGISTRY
from vllm.logger import init_logger
from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding
from vllm.model_executor.model_loader import get_model
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs
from vllm.multimodal.utils import group_mm_inputs_by_modality
from vllm.sampling_params import SamplingType
from vllm.utils import DeviceMemoryProfiler, cdiv
from vllm.v1.attention.backends.flash_attn import (FlashAttentionBackend,
FlashAttentionMetadata)
from vllm.v1.core.encoder_cache_manager import compute_encoder_budget
from vllm.v1.engine.mm_input_mapper import MMInputMapperClient
from vllm.v1.kv_cache_interface import FullAttentionSpec, KVCacheConfig
from vllm.v1.outputs import ModelRunnerOutput
from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.utils import bind_kv_cache
from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
from vllm.v1.worker.model_runner_base import ExecutionMode, ModelRunnerBase
if TYPE_CHECKING:
@ -43,41 +38,9 @@ class GPUModelRunner(ModelRunnerBase):
):
super().__init__(vllm_config, device)
# Persistent batch.
self.input_batch = InputBatch(
max_num_reqs=self.max_num_reqs,
max_model_len=self.max_model_len,
max_num_blocks_per_req=self.max_num_blocks_per_req,
device=self.device,
pin_memory=self.pin_memory,
vocab_size=self.model_config.get_vocab_size(),
)
# Request states.
self.requests: Dict[str, CachedRequestState] = {}
# KV caches for forward pass
self.kv_caches: List[torch.Tensor] = []
# Multi-modal data support
self.input_registry = INPUT_REGISTRY
self.mm_registry = MULTIMODAL_REGISTRY
# NOTE: Initialized input mapper is only used for processing dummy
# multimodal data into multimodal kwargs for GPU memory profiling.
self.mm_input_mapper_profiling = MMInputMapperClient(self.model_config)
self.mm_input_mapper_profiling.use_cache = False
encoder_compute_budget, encoder_cache_size = compute_encoder_budget(
model_config=self.model_config,
scheduler_config=self.scheduler_config,
)
self.max_num_encoder_input_tokens = encoder_compute_budget
self.encoder_cache_size = encoder_cache_size
# req_id -> (input_id -> encoder_output)
self.encoder_cache: Dict[str, Dict[int, torch.Tensor]] = {}
self.use_cuda_graph = (self.vllm_config.compilation_config.level
== CompilationLevel.PIECEWISE
and not self.model_config.enforce_eager)
@ -160,132 +123,6 @@ class GPUModelRunner(ModelRunnerBase):
pin_memory=self.pin_memory)
self.seq_lens_np = self.seq_lens_cpu.numpy()
def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
# Remove stopped requests from the cached states.
# Keep the states of the pre-empted requests.
for req_id in scheduler_output.finished_req_ids:
self.requests.pop(req_id, None)
self.encoder_cache.pop(req_id, None)
# Free the cached encoder outputs.
for req_id, input_id in scheduler_output.free_encoder_input_ids:
encoder_outputs = self.encoder_cache.get(req_id)
if encoder_outputs is not None:
encoder_outputs.pop(input_id, None)
if not encoder_outputs:
self.encoder_cache.pop(req_id, None)
# Remove the requests from the persistent batch.
stopped_req_ids = set().union(
scheduler_output.preempted_req_ids,
scheduler_output.finished_req_ids,
)
removed_req_indices: List[int] = []
for req_id in stopped_req_ids:
req_index = self.input_batch.remove_request(req_id)
if req_index is not None:
removed_req_indices.append(req_index)
# Update the states of the running requests.
for req_data in scheduler_output.scheduled_running_reqs:
req_id = req_data.req_id
req_state = self.requests[req_id]
req_index = self.input_batch.req_id_to_index[req_id]
# Update the num_computed_tokens.
req_state.num_computed_tokens = req_data.num_computed_tokens
self.input_batch.num_computed_tokens_cpu[req_index] = (
req_data.num_computed_tokens)
# Update the block table.
num_new_blocks = len(req_data.new_block_ids)
if num_new_blocks == 0:
continue
start_index = len(req_state.block_ids)
req_state.block_ids.extend(req_data.new_block_ids)
self.input_batch.block_table.append_row(req_index, start_index,
req_data.new_block_ids)
req_ids_to_add: List[str] = []
# Add new requests to the cached states.
for new_req_data in scheduler_output.scheduled_new_reqs:
req_id = new_req_data.req_id
sampling_params = new_req_data.sampling_params
if sampling_params.sampling_type == SamplingType.RANDOM_SEED:
generator = torch.Generator(device=self.device)
generator.manual_seed(sampling_params.seed)
else:
generator = None
self.requests[req_id] = CachedRequestState(
req_id=req_id,
prompt_token_ids=new_req_data.prompt_token_ids,
prompt=new_req_data.prompt,
mm_inputs=new_req_data.mm_inputs,
mm_positions=new_req_data.mm_positions,
sampling_params=sampling_params,
generator=generator,
block_ids=new_req_data.block_ids,
num_computed_tokens=new_req_data.num_computed_tokens,
output_token_ids=[],
)
# Only relevant for models using M-RoPE (e.g, Qwen2-VL)
if self.model_config.uses_mrope:
image_grid_thw = []
video_grid_thw = []
for mm_input in self.requests[req_id].mm_inputs:
if mm_input.get("image_grid_thw") is not None:
image_grid_thw.extend(
mm_input["image_grid_thw"].tolist())
if mm_input.get("video_grid_thw") is not None:
video_grid_thw.extend(
mm_input["video_grid_thw"].tolist())
hf_config = self.model_config.hf_config
self.requests[req_id].mrope_positions, \
self.requests[req_id].mrope_position_delta = \
MRotaryEmbedding.get_input_positions_tensor(
self.requests[req_id].prompt_token_ids,
image_grid_thw=image_grid_thw,
video_grid_thw=video_grid_thw,
image_token_id=hf_config.image_token_id,
video_token_id=hf_config.video_token_id,
vision_start_token_id=hf_config.vision_start_token_id,
vision_end_token_id=hf_config.vision_end_token_id,
spatial_merge_size=hf_config.vision_config.
spatial_merge_size,
)
req_ids_to_add.append(req_id)
# Update the cached states of the resumed requests.
for res_req_data in scheduler_output.scheduled_resumed_reqs:
req_id = res_req_data.req_id
req_state = self.requests[req_id]
req_state.block_ids = res_req_data.block_ids
req_state.num_computed_tokens = res_req_data.num_computed_tokens
req_ids_to_add.append(req_id)
# Add the new or resumed requests to the persistent batch.
# The smaller empty indices are filled first.
removed_req_indices = sorted(removed_req_indices, reverse=True)
for req_id in req_ids_to_add:
req_state = self.requests[req_id]
if removed_req_indices:
# Fill the empty index.
req_index = removed_req_indices.pop()
else:
# Append to the end.
req_index = None
self.input_batch.add_request(req_state, req_index)
# Condense the batched states if there are empty indices.
if removed_req_indices:
self.input_batch.condense(removed_req_indices)
def _prepare_inputs(self, scheduler_output: "SchedulerOutput"):
total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
assert total_num_scheduled_tokens > 0
@ -665,7 +502,7 @@ class GPUModelRunner(ModelRunnerBase):
) -> ModelRunnerOutput:
assert self.model is not None
self._update_states(scheduler_output)
self.update_states(scheduler_output)
if self.is_multimodal_model:
# Run the multimodal encoder if any.

View File

@ -1,5 +1,5 @@
import enum
from typing import TYPE_CHECKING, Optional
from typing import TYPE_CHECKING, Dict, List, Optional
import torch
import torch.distributed
@ -8,11 +8,18 @@ import torch.nn as nn
from vllm.attention.backends.abstract import AttentionType
from vllm.attention.layer import Attention
from vllm.config import VllmConfig
from vllm.inputs import INPUT_REGISTRY
from vllm.logger import init_logger
from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.sampling_params import SamplingType
from vllm.utils import LayerBlockType, cdiv, is_pin_memory_available
from vllm.v1.core.encoder_cache_manager import compute_encoder_budget
from vllm.v1.engine.mm_input_mapper import MMInputMapperClient
from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
KVCacheSpec)
from vllm.v1.outputs import ModelRunnerOutput
from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
if TYPE_CHECKING:
from vllm.v1.core.scheduler import SchedulerOutput
@ -75,6 +82,164 @@ class ModelRunnerBase:
self.model: Optional[nn.Module] = None
# Persistent batch.
self.input_batch = InputBatch(
max_num_reqs=self.max_num_reqs,
max_model_len=self.max_model_len,
max_num_blocks_per_req=self.max_num_blocks_per_req,
device=self.device,
pin_memory=self.pin_memory,
vocab_size=self.model_config.get_vocab_size(),
)
# Request states.
self.requests: Dict[str, CachedRequestState] = {}
# Multi-modal data support
self.input_registry = INPUT_REGISTRY
self.mm_registry = MULTIMODAL_REGISTRY
# NOTE: Initialized input mapper is only used for processing dummy
# multimodal data into multimodal kwargs for GPU memory profiling.
self.mm_input_mapper_profiling = MMInputMapperClient(self.model_config)
self.mm_input_mapper_profiling.use_cache = False
encoder_compute_budget, encoder_cache_size = compute_encoder_budget(
model_config=self.model_config,
scheduler_config=self.scheduler_config,
)
self.max_num_encoder_input_tokens = encoder_compute_budget
self.encoder_cache_size = encoder_cache_size
# req_id -> (input_id -> encoder_output)
self.encoder_cache: Dict[str, Dict[int, torch.Tensor]] = {}
def update_states(self, scheduler_output: "SchedulerOutput") -> None:
# Remove stopped requests from the cached states.
# Keep the states of the pre-empted requests.
for req_id in scheduler_output.finished_req_ids:
self.requests.pop(req_id, None)
self.encoder_cache.pop(req_id, None)
# Free the cached encoder outputs.
for req_id, input_id in scheduler_output.free_encoder_input_ids:
encoder_outputs = self.encoder_cache.get(req_id)
if encoder_outputs is not None:
encoder_outputs.pop(input_id, None)
if not encoder_outputs:
self.encoder_cache.pop(req_id, None)
# Remove the requests from the persistent batch.
stopped_req_ids = set().union(
scheduler_output.preempted_req_ids,
scheduler_output.finished_req_ids,
)
removed_req_indices: List[int] = []
for req_id in stopped_req_ids:
req_index = self.input_batch.remove_request(req_id)
if req_index is not None:
removed_req_indices.append(req_index)
# Update the states of the running requests.
for req_data in scheduler_output.scheduled_running_reqs:
req_id = req_data.req_id
req_state = self.requests[req_id]
req_index = self.input_batch.req_id_to_index[req_id]
# Update the num_computed_tokens.
req_state.num_computed_tokens = req_data.num_computed_tokens
self.input_batch.num_computed_tokens_cpu[req_index] = (
req_data.num_computed_tokens)
# Update the block table.
num_new_blocks = len(req_data.new_block_ids)
if num_new_blocks == 0:
continue
start_index = len(req_state.block_ids)
req_state.block_ids.extend(req_data.new_block_ids)
self.input_batch.block_table.append_row(req_index, start_index,
req_data.new_block_ids)
req_ids_to_add: List[str] = []
# Add new requests to the cached states.
for new_req_data in scheduler_output.scheduled_new_reqs:
req_id = new_req_data.req_id
sampling_params = new_req_data.sampling_params
if sampling_params.sampling_type == SamplingType.RANDOM_SEED:
generator = torch.Generator(device=self.device)
generator.manual_seed(sampling_params.seed)
else:
generator = None
self.requests[req_id] = CachedRequestState(
req_id=req_id,
prompt_token_ids=new_req_data.prompt_token_ids,
prompt=new_req_data.prompt,
mm_inputs=new_req_data.mm_inputs,
mm_positions=new_req_data.mm_positions,
sampling_params=sampling_params,
generator=generator,
block_ids=new_req_data.block_ids,
num_computed_tokens=new_req_data.num_computed_tokens,
output_token_ids=[],
)
# Only relevant for models using M-RoPE (e.g, Qwen2-VL)
if self.model_config.uses_mrope:
image_grid_thw = []
video_grid_thw = []
for mm_input in self.requests[req_id].mm_inputs:
if mm_input.get("image_grid_thw") is not None:
image_grid_thw.extend(
mm_input["image_grid_thw"].tolist())
if mm_input.get("video_grid_thw") is not None:
video_grid_thw.extend(
mm_input["video_grid_thw"].tolist())
hf_config = self.model_config.hf_config
self.requests[req_id].mrope_positions, \
self.requests[req_id].mrope_position_delta = \
MRotaryEmbedding.get_input_positions_tensor(
self.requests[req_id].prompt_token_ids,
image_grid_thw=image_grid_thw,
video_grid_thw=video_grid_thw,
image_token_id=hf_config.image_token_id,
video_token_id=hf_config.video_token_id,
vision_start_token_id=hf_config.vision_start_token_id,
vision_end_token_id=hf_config.vision_end_token_id,
spatial_merge_size=hf_config.vision_config.
spatial_merge_size,
)
req_ids_to_add.append(req_id)
# Update the cached states of the resumed requests.
for res_req_data in scheduler_output.scheduled_resumed_reqs:
req_id = res_req_data.req_id
req_state = self.requests[req_id]
req_state.block_ids = res_req_data.block_ids
req_state.num_computed_tokens = res_req_data.num_computed_tokens
req_ids_to_add.append(req_id)
# Add the new or resumed requests to the persistent batch.
# The smaller empty indices are filled first.
removed_req_indices = sorted(removed_req_indices, reverse=True)
for req_id in req_ids_to_add:
req_state = self.requests[req_id]
if removed_req_indices:
# Fill the empty index.
req_index = removed_req_indices.pop()
else:
# Append to the end.
req_index = None
self.input_batch.add_request(req_state, req_index)
# Condense the batched states if there are empty indices.
if removed_req_indices:
self.input_batch.condense(removed_req_indices)
def get_model(self) -> nn.Module:
assert self.model is not None
return self.model

View File

@ -15,7 +15,6 @@ from vllm.config import VllmConfig
from vllm.forward_context import set_forward_context
from vllm.logger import init_logger
from vllm.model_executor.model_loader import get_model
from vllm.sampling_params import SamplingType
from vllm.v1.attention.backends.pallas import (PallasAttentionBackend,
PallasMetadata)
from vllm.v1.kv_cache_interface import FullAttentionSpec, KVCacheConfig
@ -40,25 +39,24 @@ _MAX_NUM_SAMPLES = 128
@dataclass
class PrefillInputData:
class PromptInputData:
request_ids: List
req_ids: List
prompt_lens: List
token_ids: List
position_ids: List
input_tokens: List
input_positions: List
attn_metadata: List
def zipped(self):
return zip(self.request_ids, self.prompt_lens, self.token_ids,
self.position_ids, self.attn_metadata)
return zip(self.req_ids, self.prompt_lens, self.input_tokens,
self.input_positions, self.attn_metadata)
@dataclass
class DecodeInputData:
num_decodes: int
token_ids: Optional[torch.Tensor] = None
position_ids: Optional[torch.Tensor] = None
req_ids: List
input_tokens: Optional[torch.Tensor] = None
input_positions: Optional[torch.Tensor] = None
attn_metadata: Optional[PallasMetadata] = None
@ -88,158 +86,105 @@ class TPUModelRunner(ModelRunnerBase):
self.kv_caches: List[Tuple[torch.Tensor, torch.Tensor]] = []
# Used to initialize positions for the individual prefills
self.prefill_positions = torch.tensor(range(self.max_model_len),
device="cpu",
dtype=torch.int32).reshape(
1, -1)
self.prefill_input_positions = torch.tensor(range(self.max_model_len),
device="cpu",
dtype=torch.int32).reshape(
1, -1)
# Used to indicate how many prefills there are for each scheduler
# iteration
self.num_new_reqs: int = 0
def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
# Remove stopped requests from the cached states.
# Keep the states of the pre-empted requests.
for req_id in scheduler_output.finished_req_ids:
self.requests.pop(req_id, None)
# Remove the requests from the persistent batch.
stopped_req_ids = set().union(
scheduler_output.preempted_req_ids,
scheduler_output.finished_req_ids,
)
removed_req_indices: List[int] = []
for req_id in stopped_req_ids:
req_index = self.input_batch.remove_request(req_id)
if req_index is not None:
removed_req_indices.append(req_index)
# Update the states of the running requests.
for req_data in scheduler_output.scheduled_running_reqs:
req_id = req_data.req_id
req_state = self.requests[req_id]
req_index = self.input_batch.req_id_to_index[req_id]
# Update the num_computed_tokens.
req_state.num_computed_tokens = req_data.num_computed_tokens
self.input_batch.num_computed_tokens_cpu[req_index] = (
req_data.num_computed_tokens)
# Update the block table.
num_new_blocks = len(req_data.new_block_ids)
if num_new_blocks == 0:
continue
start_index = len(req_state.block_ids)
req_state.block_ids.extend(req_data.new_block_ids)
self.input_batch.block_table.append_row(req_index, start_index,
req_data.new_block_ids)
req_ids_to_add: List[str] = []
# Add new requests to the cached states.
for new_req_data in scheduler_output.scheduled_new_reqs:
req_id = new_req_data.req_id
sampling_params = new_req_data.sampling_params
if sampling_params.sampling_type == SamplingType.RANDOM_SEED:
generator = torch.Generator(device=self.device)
generator.manual_seed(sampling_params.seed)
else:
generator = None
self.requests[req_id] = CachedRequestState(
req_id=req_id,
prompt_token_ids=new_req_data.prompt_token_ids,
prompt=new_req_data.prompt,
mm_inputs=new_req_data.mm_inputs,
mm_positions=new_req_data.mm_positions,
sampling_params=sampling_params,
generator=generator,
block_ids=new_req_data.block_ids,
num_computed_tokens=new_req_data.num_computed_tokens,
output_token_ids=[],
)
req_ids_to_add.append(req_id)
# Update the cached states of the resumed requests.
for res_req_data in scheduler_output.scheduled_resumed_reqs:
req_id = res_req_data.req_id
req_state = self.requests[req_id]
req_state.block_ids = res_req_data.block_ids
req_state.num_computed_tokens = res_req_data.num_computed_tokens
req_ids_to_add.append(req_id)
# For TPU, we keep all of the decode requests before the
# prefill requests in the batch sequence.
# 1. First condense, so all decodes move to start
# 2. Then add new prefills to the end of the batch
removed_req_indices = sorted(removed_req_indices, reverse=True)
if removed_req_indices:
self.input_batch.condense(removed_req_indices)
for req_id in req_ids_to_add:
req_state = self.requests[req_id]
self.input_batch.add_request(req_state, None) # Append last
self.num_new_reqs = len(req_ids_to_add)
def _prepare_prefill_inputs(
def _prepare_prompt_inputs(
self,
num_scheduled_tokens: List[int],
) -> PrefillInputData:
# Each prefill run separately with shape [1, padded_prompt_len].
# So we create lists that will be used in execute_model().
prefill_request_ids = []
prefill_prompt_lens = []
prefill_token_ids = []
prefill_position_ids = []
prefill_attn_metadata = []
# DECODES are the first num_decodes REQUESTS.
# PREFILLS are the next num_reqs - num_decodes REQUESTS.
scheduler_output: "SchedulerOutput",
) -> PromptInputData:
total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
assert total_num_scheduled_tokens > 0
num_reqs = self.input_batch.num_reqs
num_decodes = num_reqs - self.num_new_reqs
for idx in range(num_decodes, num_reqs):
req_id = self.input_batch.req_ids[idx]
prefill_request_ids.append(req_id)
assert num_reqs > 0
prompt_len = num_scheduled_tokens[idx]
prefill_prompt_lens.append(prompt_len)
# OPTIMIZATION: Start copying the block table first.
# This way, we can overlap the copy with the following CPU operations.
self.input_batch.block_table.commit(num_reqs)
# STATIC SHAPE: prefills are padded to the next power of 2.
req_ids = []
prompt_lens = []
input_tokens_list = []
input_positions_list = []
attn_metadata_list = []
for req_id in self.input_batch.req_ids[:num_reqs]:
assert req_id is not None
req_index = self.input_batch.req_id_to_index[req_id]
req_state = self.requests[req_id]
num_scheduled_tokens = scheduler_output.num_scheduled_tokens[
req_id]
num_computed_tokens = req_state.num_computed_tokens
num_prompt_tokens = len(req_state.prompt_token_ids)
# Detect whether this is a prompt (can be full or chunked)
if num_computed_tokens >= num_prompt_tokens:
# This is a decode => Skip
continue
# This is a prompt
req_ids.append(req_id)
# Prompt len
prompt_len = num_scheduled_tokens
prompt_lens.append(prompt_len)
padded_prompt_len = _get_padded_prefill_len(prompt_len)
assert padded_prompt_len <= self.max_model_len
# TOKEN_IDS.
token_ids = torch.from_numpy(self.input_batch.token_ids_cpu[
idx, :padded_prompt_len].reshape(1, -1))
token_ids[:, prompt_len:] = 0
prefill_token_ids.append(token_ids.to(self.device))
# Seq len
seq_len = num_computed_tokens + prompt_len
padded_seq_len = num_computed_tokens + padded_prompt_len
# POSITIONS.
positions = self.prefill_positions[:, :padded_prompt_len].clone()
positions[:, prompt_len:] = 0
prefill_position_ids.append(positions.to(self.device))
# Input tokens
input_tokens = torch.from_numpy(self.input_batch.token_ids_cpu[
req_index, num_computed_tokens:padded_seq_len].reshape(1, -1))
input_tokens[:, prompt_len:] = 0
input_tokens_list.append(input_tokens.to(self.device))
# SLOT_MAPPING.
# The "slot" is the "physical index" of a token in the KV cache.
# Look up the block_idx in the block table (logical<>physical map)
# to compute this.
# Input positions
input_positions = self.prefill_input_positions[:,
num_computed_tokens:
padded_seq_len].clone(
)
input_positions[:, prompt_len:] = 0
input_positions_list.append(input_positions.to(self.device))
# Slot mapping
block_table_cpu_tensor = \
self.input_batch.block_table.get_cpu_tensor()
block_numbers = block_table_cpu_tensor[idx, positions //
block_numbers = block_table_cpu_tensor[req_index,
input_positions //
self.block_size].reshape(
1, -1)
block_offsets = positions % self.block_size
block_offsets = input_positions % self.block_size
slot_mapping = block_numbers * self.block_size + block_offsets
# Set an out of range value for the padding tokens so that they
# are ignored when inserting into the KV cache.
slot_mapping[:, prompt_len:] = _PAD_SLOT_ID
slot_mapping = slot_mapping.long()
prefill_attn_metadata.append(
# Block table
block_table = None
if num_computed_tokens > 0:
block_table = self.input_batch.block_table.get_device_tensor()
block_table = block_table[req_index].unsqueeze(0)
# Context len
context_len = 0
if num_computed_tokens > 0:
context_len = seq_len
context_lens = torch.tensor([context_len],
dtype=torch.int32,
device="cpu")
# Effective query len
effective_query_lens = torch.tensor([prompt_len],
dtype=torch.int32,
device="cpu")
# Attn metadata
attn_metadata_list.append(
PallasMetadata(
num_prefills=1,
num_prefill_tokens=0, # NOTE: This is not used.
@ -247,208 +192,200 @@ class TPUModelRunner(ModelRunnerBase):
slot_mapping=slot_mapping.to(self.device),
multi_modal_placeholder_index_maps=None,
enable_kv_scales_calculation=True,
block_tables=None,
context_lens=None,
effective_query_lens=None,
block_tables=block_table,
context_lens=context_lens.to(self.device),
effective_query_lens=effective_query_lens.to(self.device),
))
return PrefillInputData(
request_ids=prefill_request_ids,
prompt_lens=prefill_prompt_lens,
token_ids=prefill_token_ids,
position_ids=prefill_position_ids,
attn_metadata=prefill_attn_metadata,
return PromptInputData(
req_ids=req_ids,
prompt_lens=prompt_lens,
input_tokens=input_tokens_list,
input_positions=input_positions_list,
attn_metadata=attn_metadata_list,
)
def _prepare_decode_inputs(self) -> DecodeInputData:
# Decodes run as one single padded batch with shape [batch, 1]
#
# We need to set _PAD_SLOT_ID for the padding tokens in the
# slot_mapping, such that the attention KV cache insertion
# logic knows to ignore those indices. Otherwise, the
# padding data can be dummy since we have a causal mask.
# DECODES are the first num_decodes REQUESTS.
# PREFILLS are the next num_reqs - num_decodes REQUESTS.
num_reqs = self.input_batch.num_reqs
num_decodes = num_reqs - self.num_new_reqs
if num_decodes == 0:
return DecodeInputData(num_decodes=0)
# PAD FOR STATIC SHAPES.
padded_batch_size = _get_padded_batch_size(num_decodes)
# POSITIONS. [batch, 1]
# We slice at the end, since we use the positions for gathering.
positions = torch.from_numpy(
self.input_batch.num_computed_tokens_cpu.reshape(-1, 1))
index = positions.to(torch.int64)
index[num_decodes:] = 0
positions = positions[:padded_batch_size]
positions[num_decodes:] = 0
# TOKEN_IDS. [batch, 1]
token_ids = torch.gather(
input=torch.from_numpy(self.input_batch.token_ids_cpu),
dim=1,
index=index,
)[:padded_batch_size].to(torch.int32)
token_ids[num_decodes:] = 0
# SLOT_MAPPING [batch, 1]
# The "slot" is the "physical index" of a token in the KV cache.
# Look up the block_idx in the block table (logical<>physical map)
# to compute this.
block_table_cpu_tensor = self.input_batch.block_table.get_cpu_tensor()
block_number = torch.gather(input=block_table_cpu_tensor,
dim=1,
index=(index // self.block_size))
block_offsets = index % self.block_size
slot_mapping = block_number * self.block_size + block_offsets
# Set an out of range value for the padding tokens so that they
# are ignored when inserting into the KV cache.
slot_mapping[num_decodes:] = _PAD_SLOT_ID
slot_mapping = slot_mapping[:padded_batch_size]
slot_mapping = slot_mapping.long()
# BLOCK_TABLE [batch, max_num_blocks_per_req]
block_table = block_table_cpu_tensor[:padded_batch_size]
# CONTEXT_LENS [batch_size]
context_lens = (positions.reshape(-1) + 1)
context_lens[num_decodes:] = 0
# CPU<>TPU sync happens here.
return DecodeInputData(num_decodes=num_decodes,
token_ids=token_ids.to(self.device),
position_ids=positions.to(self.device),
attn_metadata=PallasMetadata(
num_prefills=0,
num_prefill_tokens=0,
num_decode_tokens=padded_batch_size,
slot_mapping=slot_mapping.to(self.device),
multi_modal_placeholder_index_maps=None,
enable_kv_scales_calculation=True,
block_tables=block_table.to(self.device),
context_lens=context_lens.to(self.device),
effective_query_lens=None,
))
def _prepare_inputs(self, scheduler_output: "SchedulerOutput"):
def _prepare_decode_inputs(
self,
scheduler_output: "SchedulerOutput",
) -> DecodeInputData:
total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
assert total_num_scheduled_tokens > 0
num_reqs = self.input_batch.num_reqs
assert num_reqs > 0
num_decodes = num_reqs - self.num_new_reqs
block_table_cpu_tensor = self.input_batch.block_table.get_cpu_tensor()
# TODO: Resurrect
# OPTIMIZATION: Start copying the block table first.
# This way, we can overlap the copy with the following CPU operations.
# TODO: Verify this works with TPUs
# self.input_batch.block_table.commit(num_reqs)
# Get the number of scheduled tokens for each request.
# TODO: The Python loop can be slow. Optimize.
num_scheduled_tokens = []
max_num_scheduled_tokens = 0
for idx, req_id in enumerate(self.input_batch.req_ids[:num_reqs]):
req_ids = []
req_indices = []
input_tokens = []
input_positions = []
slot_mapping = []
context_lens = []
for req_id in self.input_batch.req_ids[:num_reqs]:
assert req_id is not None
num_tokens = scheduler_output.num_scheduled_tokens[req_id]
num_scheduled_tokens.append(num_tokens)
max_num_scheduled_tokens = max(max_num_scheduled_tokens,
num_tokens)
req_index = self.input_batch.req_id_to_index[req_id]
req_state = self.requests[req_id]
# NOTE: Assert that all the decodes are "decodes".
if idx < num_decodes:
assert num_tokens == 1
num_scheduled_tokens = scheduler_output.num_scheduled_tokens[
req_id]
num_computed_tokens = req_state.num_computed_tokens
num_prompt_tokens = len(req_state.prompt_token_ids)
assert max_num_scheduled_tokens > 0
# Detect whether this is a decode
if num_computed_tokens < num_prompt_tokens:
# This is a prompt => Skip
continue
return (
self._prepare_prefill_inputs(num_scheduled_tokens),
self._prepare_decode_inputs(),
# This is a decode
req_ids.append(req_id)
req_indices.append(req_index)
# Seq len
seq_len = num_computed_tokens + num_scheduled_tokens
# Sanity check decode
assert num_scheduled_tokens == 1
assert seq_len == req_state.num_tokens
# Input token
input_tokens.append([
self.input_batch.token_ids_cpu[req_index, num_computed_tokens]
])
# Position
input_positions.append([num_computed_tokens])
# Slot mapping
block_number = block_table_cpu_tensor[req_index,
num_computed_tokens //
self.block_size]
block_offset = num_computed_tokens % self.block_size
slot_id = block_number * self.block_size + block_offset
slot_mapping.append([slot_id])
# Context len
context_lens.append(seq_len)
# Compute padding
batch_size = len(input_tokens)
padded_batch_size = _get_padded_batch_size(batch_size)
num_padding = padded_batch_size - batch_size
# Add padding
input_tokens.extend([[0]] * num_padding)
input_positions.extend([[0]] * num_padding)
slot_mapping.extend([[_PAD_SLOT_ID]] * num_padding)
context_lens.extend([0] * num_padding)
req_indices.extend([0] * num_padding)
# Create tensors
input_tokens_tensor = torch.tensor(input_tokens,
dtype=torch.int32,
device="cpu")
input_positions_tensor = torch.tensor(input_positions,
dtype=torch.int32,
device="cpu")
slot_mapping_tensor = torch.tensor(slot_mapping,
dtype=torch.int64,
device="cpu")
context_lens_tensor = torch.tensor(context_lens,
dtype=torch.int32,
device="cpu")
block_tables_tensor = block_table_cpu_tensor[req_indices]
# Attn metadata
attn_metadata = PallasMetadata(
num_prefills=0,
num_prefill_tokens=0,
num_decode_tokens=padded_batch_size,
slot_mapping=slot_mapping_tensor.to(self.device),
multi_modal_placeholder_index_maps=None,
enable_kv_scales_calculation=True,
block_tables=block_tables_tensor.to(self.device),
context_lens=context_lens_tensor.to(self.device),
effective_query_lens=None,
)
return DecodeInputData(
req_ids=req_ids,
input_tokens=input_tokens_tensor.to(self.device),
input_positions=input_positions_tensor.to(self.device),
attn_metadata=attn_metadata)
@torch.no_grad()
def execute_model(
self,
scheduler_output: "SchedulerOutput",
) -> ModelRunnerOutput:
self._update_states(scheduler_output)
# Update cached state
self.update_states(scheduler_output)
# Prepare the decoder inputs.
prefill_data, decode_data = self._prepare_inputs(scheduler_output)
# Prepare inputs
prompt_data = self._prepare_prompt_inputs(scheduler_output)
decode_data = self._prepare_decode_inputs(scheduler_output)
# Init
num_reqs = self.input_batch.num_reqs
assert num_reqs > 0
sampled_token_ids_list = [0] * num_reqs
######################### DECODES #########################
# Decodes run as one single batch with [padded_batch, 1]
sampled_token_ids_list = []
if decode_data.num_decodes > 0:
# FORWARD.
# Run decodes (a single batch)
if len(decode_data.req_ids) > 0:
# Forward
with set_forward_context(decode_data.attn_metadata,
self.vllm_config):
assert self.model is not None
selected_token_ids = self.model(decode_data.token_ids,
decode_data.position_ids,
selected_token_ids = self.model(decode_data.input_tokens,
decode_data.input_positions,
decode_data.attn_metadata,
self.kv_caches)
# NOTE: TPU<>CPU sync happens here.
# We need to call .cpu() first to avoid recompilation.
token_ids = selected_token_ids.cpu()[:decode_data.num_decodes]
sampled_token_ids_list.extend(token_ids.tolist())
# Transfer sampled tokens from TPU to CPU
selected_token_ids_list = selected_token_ids.cpu().tolist()
# UPDATE REQUEST STATE.
for i, req_id in enumerate(
self.input_batch.req_ids[:decode_data.num_decodes]):
assert req_id is not None
# Update cached state
for i, req_id in enumerate(decode_data.req_ids):
req_index = self.input_batch.req_id_to_index[req_id]
req_state = self.requests[req_id]
assert scheduler_output.num_scheduled_tokens[req_id] == 1
seq_len = (req_state.num_computed_tokens +
scheduler_output.num_scheduled_tokens[req_id])
assert seq_len == req_state.num_tokens
token_id = sampled_token_ids_list[i]
self.input_batch.token_ids_cpu[i, seq_len] = token_id
self.input_batch.num_tokens[i] += 1
token_id = selected_token_ids_list[i]
self.input_batch.token_ids_cpu[req_index, seq_len] = token_id
self.input_batch.num_tokens[req_index] += 1
req_state.output_token_ids.append(token_id)
######################### PREFILLS #########################
# Prefills run separately with shape [1, padded_prefill_len],
# due to lack of variable length attention kernel so far.
for (req_id, prompt_len, token_ids, position_ids,
attn_metadata) in prefill_data.zipped():
assert req_id is not None
sampled_token_ids_list[req_index] = token_id
# FORWARD.
# Run each prompt
for (req_id, prompt_len, input_tokens, input_positions,
attn_metadata) in prompt_data.zipped():
assert req_id is not None
req_state = self.requests[req_id]
req_index = self.input_batch.req_id_to_index[req_id]
# Forward
with set_forward_context(attn_metadata, self.vllm_config):
assert self.model is not None
selected_token_ids = self.model(token_ids, position_ids,
selected_token_ids = self.model(input_tokens, input_positions,
attn_metadata, self.kv_caches)
# NOTE: TPU<>CPU sync happens here.
# We need to call .cpu() first to avoid recompilation.
token_id = selected_token_ids.cpu()[prompt_len - 1].item()
sampled_token_ids_list.append(token_id)
req_state = self.requests[req_id]
assert req_state.num_computed_tokens == 0
seq_len = (req_state.num_computed_tokens +
scheduler_output.num_scheduled_tokens[req_id])
assert seq_len == req_state.num_tokens
assert prompt_len == seq_len
if seq_len >= len(req_state.prompt_token_ids):
# Transfer sampled tokens from TPU to CPU
token_id = selected_token_ids.cpu()[prompt_len - 1].item()
sampled_token_ids_list[req_index] = token_id
# UPDATE REQUEST STATE.
req_idx = self.input_batch.req_id_to_index[req_id]
self.input_batch.token_ids_cpu[req_idx, seq_len] = token_id
self.input_batch.num_tokens[req_idx] += 1
req_state.output_token_ids.append(token_id)
# Update cached state
self.input_batch.token_ids_cpu[req_index, seq_len] = token_id
self.input_batch.num_tokens[req_index] += 1
req_state.output_token_ids.append(token_id)
# num_reqs entries should be non-None
# Get req_ids
assert all(
req_id is not None for req_id in
self.input_batch.req_ids[:num_reqs]), "req_ids contains None"