# SPDX-License-Identifier: Apache-2.0
import enum
import time
from dataclasses import dataclass
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple
from unittest.mock import patch

import numpy as np
import torch
import torch.distributed
import torch.nn as nn
# TPU XLA related
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr

from vllm.attention import AttentionMetadata
from vllm.attention.backends.abstract import AttentionType
from vllm.attention.layer import Attention
from vllm.config import VllmConfig
from vllm.forward_context import set_forward_context
from vllm.logger import init_logger
from vllm.model_executor.model_loader import get_model
from vllm.sampling_params import SamplingType
from vllm.utils import LayerBlockType, cdiv, is_pin_memory_available
from vllm.v1.attention.backends.pallas import (PallasAttentionBackend,
                                               PallasMetadata)
from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
                                        KVCacheSpec)
from vllm.v1.outputs import LogprobsTensors, ModelRunnerOutput
from vllm.v1.utils import bind_kv_cache
from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch

if TYPE_CHECKING:
    from vllm.v1.core.scheduler import SchedulerOutput

logger = init_logger(__name__)

# Here we utilize the behavior that out-of-bound index is ignored.
# FIXME(woosuk): Find a more reliable way to prevent possible bugs.
_PAD_SLOT_ID = 1_000_000_000


class ExecutionMode(enum.Enum):
    PREFILL = enum.auto()
    DECODE = enum.auto()
    PREFIX_PREFILL = enum.auto()

    def is_prefill(self) -> bool:
        return self in (ExecutionMode.PREFILL, ExecutionMode.PREFIX_PREFILL)


@dataclass
class PromptDecodeInfo:
    prompt_req_ids: List[str]
    decode_req_ids: List[str]
    prompt_scheduled_tokens: List[int]


@dataclass
class PromptData:
    input_tokens: torch.Tensor
    input_positions: torch.Tensor
    attn_metadata: PallasMetadata


@dataclass
class DecodeData:
    input_tokens: Optional[torch.Tensor] = None
    input_positions: Optional[torch.Tensor] = None
    attn_metadata: Optional[PallasMetadata] = None


class TPUModelRunner:

    def __init__(
        self,
        vllm_config: VllmConfig,
        device: torch.device,
    ):
        self.vllm_config = vllm_config
        self.model_config = vllm_config.model_config
        self.cache_config = vllm_config.cache_config
        self.lora_config = vllm_config.lora_config
        self.load_config = vllm_config.load_config
        self.parallel_config = vllm_config.parallel_config
        self.scheduler_config = vllm_config.scheduler_config
        self.speculative_config = vllm_config.speculative_config
        self.prompt_adapter_config = vllm_config.prompt_adapter_config
        self.observability_config = vllm_config.observability_config
        self.device_config = vllm_config.device_config

        model_config = self.model_config
        cache_config = self.cache_config
        scheduler_config = self.scheduler_config
        parallel_config = self.parallel_config
        self.device = device
        self.pin_memory = is_pin_memory_available()
        self.dtype = self.model_config.dtype

        self.is_multimodal_model = model_config.is_multimodal_model
        self.sliding_window = model_config.get_sliding_window()
        self.block_size = cache_config.block_size
        self.max_model_len = model_config.max_model_len
        self.max_num_blocks_per_req = cdiv(self.max_model_len,
                                           self.block_size)
        self.max_num_tokens = scheduler_config.max_num_batched_tokens
        self.max_num_reqs = scheduler_config.max_num_seqs
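
        # Illustrative (assumed example values, not taken from any config):
        # with max_model_len=2048 and block_size=16, max_num_blocks_per_req
        # == cdiv(2048, 16) == 128, i.e. the per-request block table can
        # cover the whole context.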

        # Model-related.
        self.num_attn_layers = model_config.get_num_layers_by_block_type(
            parallel_config, LayerBlockType.attention)
        self.num_query_heads = model_config.get_num_attention_heads(
            parallel_config)
        self.num_kv_heads = model_config.get_num_kv_heads(parallel_config)
        self.head_size = model_config.get_head_size()
        self.hidden_size = model_config.get_hidden_size()

        self.model: Optional[nn.Module] = None

        # Persistent batch.
        self.input_batch = InputBatch(
            max_num_reqs=self.max_num_reqs,
            max_model_len=self.max_model_len,
            max_num_blocks_per_req=self.max_num_blocks_per_req,
            device=self.device,
            pin_memory=self.pin_memory,
            vocab_size=self.model_config.get_vocab_size(),
        )

        # Request states.
        self.requests: Dict[str, CachedRequestState] = {}

        # req_id -> (input_id -> encoder_output)
        self.encoder_cache: Dict[str, Dict[int, torch.Tensor]] = {}

        # KV caches for forward pass
        self.kv_caches: List[Tuple[torch.Tensor, torch.Tensor]] = []

        # Cached torch/numpy staging tensors, double-buffered (num_swaps
        # copies) so that the inputs for the next step can be prepared on the
        # CPU while the current step is still running on the TPU.
        self.num_swaps = 2
        self.cur_swap_id = 0
        self.input_ids_cpu = []
        self.input_ids_np = []
        self.input_positions_cpu = []
        self.input_positions_np = []
        self.slot_mapping_cpu = []
        self.slot_mapping_np = []
        self.prompt_context_lens_cpu = []
        self.prompt_effective_query_lens_cpu = []
        self.decode_context_lens_cpu = []
        self.decode_context_lens_np = []
        for _ in range(self.num_swaps):
            self.input_ids_cpu.append(
                torch.empty(self.max_num_tokens,
                            dtype=torch.int32,
                            device="cpu"))
            self.input_ids_np.append(self.input_ids_cpu[-1].numpy())

            self.input_positions_cpu.append(
                torch.empty(self.max_num_tokens,
                            dtype=torch.int32,
                            device="cpu"))
            self.input_positions_np.append(
                self.input_positions_cpu[-1].numpy())

            self.slot_mapping_cpu.append(
                torch.empty(self.max_num_tokens,
                            dtype=torch.int64,
                            device="cpu"))
            self.slot_mapping_np.append(self.slot_mapping_cpu[-1].numpy())

            self.prompt_context_lens_cpu.append(
                torch.empty((1), dtype=torch.int32, device="cpu"))
            self.prompt_effective_query_lens_cpu.append(
                torch.empty((1), dtype=torch.int32, device="cpu"))

            self.decode_context_lens_cpu.append(
                torch.empty(self.max_num_tokens,
                            dtype=torch.int32,
                            device="cpu"))
            self.decode_context_lens_np.append(
                self.decode_context_lens_cpu[-1].numpy())

        # Range tensor with values [0 .. self.max_num_tokens - 1].
        # Used to initialize positions / context_lens / seq_lens
        self.arange_np = np.arange(self.max_num_tokens, dtype=np.int32)

    def _update_states(self, scheduler_output: "SchedulerOutput") -> bool:
        """Update the cached states and the persistent batch with the
        scheduler output.

        The updated states are used by the `_prepare_inputs` function to
        create the input device tensors for the model.

        Returns:
            True if there is a new/resumed/paused/finished request in the
            batch. If False, we can skip copying SamplingMetadata to the
            device.
        """
        # Remove finished requests from the cached states.
        for req_id in scheduler_output.finished_req_ids:
            self.requests.pop(req_id, None)

        # Remove the finished requests from the persistent batch.
        # NOTE(woosuk): There could be an edge case where finished_req_ids and
        # scheduled_req_ids overlap. This happens when a request is aborted
        # and then resubmitted with the same ID. In this case, we treat them
        # as two distinct requests - clearing the cached states for the first
        # request and handling the second as a new request.
        removed_req_indices: List[int] = []
        for req_id in scheduler_output.finished_req_ids:
            req_index = self.input_batch.remove_request(req_id)
            if req_index is not None:
                removed_req_indices.append(req_index)

        # Remove the unscheduled requests from the persistent batch.
        # NOTE(woosuk): The unscheduled requests are either preempted requests
        # or running requests that are not scheduled in this step. We remove
        # them from the persistent batch but keep their cached states since
        # they will be scheduled again sometime in the future.
        scheduled_req_ids = scheduler_output.num_scheduled_tokens.keys()
        cached_req_ids = self.input_batch.req_id_to_index.keys()
        unscheduled_req_ids = cached_req_ids - scheduled_req_ids
        # NOTE(woosuk): The persistent batch optimization assumes that
        # consecutive batches contain mostly the same requests. If batches
        # have low request overlap (e.g., alternating between two distinct
        # sets of requests), this optimization becomes very inefficient.
        for req_id in unscheduled_req_ids:
            req_index = self.input_batch.remove_request(req_id)
            assert req_index is not None
            removed_req_indices.append(req_index)

        req_ids_to_add: List[str] = []
        # Add new requests to the cached states.
        for new_req_data in scheduler_output.scheduled_new_reqs:
            req_id = new_req_data.req_id
            sampling_params = new_req_data.sampling_params
            if sampling_params.sampling_type == SamplingType.RANDOM_SEED:
                generator = torch.Generator(device=self.device)
                generator.manual_seed(sampling_params.seed)
            else:
                generator = None

            self.requests[req_id] = CachedRequestState(
                req_id=req_id,
                prompt_token_ids=new_req_data.prompt_token_ids,
                prompt=new_req_data.prompt,
                mm_inputs=new_req_data.mm_inputs,
                mm_positions=new_req_data.mm_positions,
                sampling_params=sampling_params,
                generator=generator,
                block_ids=new_req_data.block_ids,
                num_computed_tokens=new_req_data.num_computed_tokens,
                output_token_ids=[],
                lora_request=new_req_data.lora_request,
            )
            req_ids_to_add.append(req_id)

        # Update the states of the running/resumed requests.
        for req_data in scheduler_output.scheduled_cached_reqs:
            req_id = req_data.req_id
            req_state = self.requests[req_id]

            # Update the cached states.
            req_state.num_computed_tokens = req_data.num_computed_tokens
            if not req_data.resumed_from_preemption:
                # Append the new blocks to the existing block IDs.
                req_state.block_ids.extend(req_data.new_block_ids)
            else:
                # The request is resumed from preemption.
                # Replace the existing block IDs with the new ones.
                req_state.block_ids = req_data.new_block_ids

            req_index = self.input_batch.req_id_to_index.get(req_id)
            if req_index is None:
                # The request is not in the persistent batch.
                # The request was either preempted and resumed later, or was
                # not scheduled in the previous step and needs to be added
                # again.
                req_ids_to_add.append(req_id)
                continue

            # Update the persistent batch.
            self.input_batch.num_computed_tokens_cpu[req_index] = (
                req_data.num_computed_tokens)

            start_index = len(req_state.block_ids) - len(
                req_data.new_block_ids)
            self.input_batch.block_table.append_row(req_index, start_index,
                                                    req_data.new_block_ids)

        # Add the new or resumed requests to the persistent batch.
        # The smaller empty indices are filled first.
        removed_req_indices = sorted(removed_req_indices, reverse=True)
        for req_id in req_ids_to_add:
            req_state = self.requests[req_id]
            if removed_req_indices:
                # Fill the empty index.
                req_index = removed_req_indices.pop()
            else:
                # Append to the end.
                req_index = None
            self.input_batch.add_request(req_state, req_index)

        # Condense the batched states if there are empty indices.
        if removed_req_indices:
            self.input_batch.condense(removed_req_indices)
        return len(unscheduled_req_ids) > 0 or len(req_ids_to_add) > 0

    def swap_step(self):
        self.cur_swap_id = (self.cur_swap_id + 1) % self.num_swaps

    def get_model(self) -> nn.Module:
        assert self.model is not None
        return self.model

    def get_kv_cache_spec(self) -> KVCacheSpec:
        """
        Generates the KVCacheSpec by parsing the kv cache format from each
        Attention module in the static forward context.

        Returns:
            KVCacheSpec: A dictionary mapping layer names to their KV cache
            format. Layers that do not need KV cache are not included.
        """
        forward_ctx = self.vllm_config.compilation_config.static_forward_context
        block_size = self.vllm_config.cache_config.block_size
        kv_cache_spec: KVCacheSpec = {}
        for layer_name, attn_module in forward_ctx.items():
            # TODO: Support other attention modules, e.g., sliding window,
            # cross-attention, MLA.
            assert isinstance(attn_module, Attention)
            if attn_module.attn_type == AttentionType.DECODER:
                kv_cache_spec[layer_name] = FullAttentionSpec(
                    block_size=block_size,
                    num_kv_heads=attn_module.num_kv_heads,
                    head_size=attn_module.head_size,
                    dtype=attn_module.dtype,
                )
            elif attn_module.attn_type in (AttentionType.ENCODER,
                                           AttentionType.ENCODER_ONLY):
                # encoder-only attention does not need KV cache.
                continue
            elif attn_module.attn_type == AttentionType.ENCODER_DECODER:
                raise NotImplementedError
            else:
                raise ValueError(
                    f"Unknown attention type: {attn_module.attn_type}")

        return kv_cache_spec

    def _get_prompts_and_decodes(
        self,
        scheduler_output: "SchedulerOutput",
    ) -> PromptDecodeInfo:
        total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
        assert total_num_scheduled_tokens > 0

        num_reqs = self.input_batch.num_reqs
        assert num_reqs > 0

        # Traverse decodes first
        decode_req_ids = []
        for i in range(num_reqs):
            req_id = self.input_batch.req_ids[i]
            assert req_id is not None

            num_computed_tokens = self.input_batch.num_computed_tokens_cpu[i]
            num_prompt_tokens = self.input_batch.num_prompt_tokens[i]
            num_scheduled_tokens = scheduler_output.num_scheduled_tokens[
                req_id]

            if num_computed_tokens < num_prompt_tokens:
                # This is prompt
                break

            # This is decode
            assert num_scheduled_tokens == 1
            decode_req_ids.append(req_id)

        # Traverse prompts
        prompt_req_ids = []
        prompt_scheduled_tokens = []
        for i in range(len(decode_req_ids), num_reqs):
            req_id = self.input_batch.req_ids[i]
            assert req_id is not None

            num_computed_tokens = self.input_batch.num_computed_tokens_cpu[i]
            num_prompt_tokens = self.input_batch.num_prompt_tokens[i]
            num_scheduled_tokens = scheduler_output.num_scheduled_tokens[
                req_id]

            # Must be prompt
            assert num_computed_tokens < num_prompt_tokens
            prompt_req_ids.append(req_id)
            prompt_scheduled_tokens.append(num_scheduled_tokens)

        return PromptDecodeInfo(prompt_req_ids, decode_req_ids,
                                prompt_scheduled_tokens)

    def _prepare_prompt(self, req_index: int,
                        num_scheduled_tokens: int) -> PromptData:
        num_computed_tokens = self.input_batch.num_computed_tokens_cpu[
            req_index]
        num_prompt_tokens = self.input_batch.num_prompt_tokens[req_index]

        # Must be prompt
        assert num_computed_tokens < num_prompt_tokens

        # Prompt len
        prompt_len = num_scheduled_tokens
        padded_prompt_len = _get_padded_prompt_len(prompt_len)
        assert padded_prompt_len <= self.max_model_len

        # Seq len
        seq_len = num_computed_tokens + prompt_len
        padded_seq_len = num_computed_tokens + padded_prompt_len

        # Input tokens
        input_tokens_cpu = self.input_batch.token_ids_cpu_tensor[
            req_index, num_computed_tokens:padded_seq_len]
        input_tokens_cpu[prompt_len:] = 0
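        # The tail of the padded window (indices >= prompt_len) is filled
        # with zeros here; the matching slot_mapping entries are set to
        # _PAD_SLOT_ID below, so KV-cache writes for the padding tokens land
        # out of bounds and are ignored.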

        # Input positions
        input_positions_np = self.input_positions_np[
            self.cur_swap_id][:padded_prompt_len]
        np.add(num_computed_tokens,
               self.arange_np[:padded_prompt_len],
               out=input_positions_np)
        input_positions_np[prompt_len:] = 0

        # Slot mapping
        block_table_np = \
            self.input_batch.block_table.get_numpy_array()
        block_numbers_np = block_table_np[req_index, input_positions_np //
                                          self.block_size]
        block_offsets_np = input_positions_np % self.block_size

        slot_mapping_np = self.slot_mapping_np[
            self.cur_swap_id][:padded_prompt_len]
        np.add(block_numbers_np * self.block_size,
               block_offsets_np,
               out=slot_mapping_np)
        slot_mapping_np[prompt_len:] = _PAD_SLOT_ID

        # Block table
        block_table_cpu = None
        if num_computed_tokens > 0:
            block_table_cpu = self.input_batch.block_table.get_cpu_tensor()
            block_table_cpu = block_table_cpu[req_index]

        # Context len
        self.prompt_context_lens_cpu[self.cur_swap_id][0] = 0
        if num_computed_tokens > 0:
            self.prompt_context_lens_cpu[self.cur_swap_id][0] = seq_len

        # Effective query len
        self.prompt_effective_query_lens_cpu[self.cur_swap_id][0] = prompt_len

        # Get final tensors
        input_tokens = input_tokens_cpu.reshape(1, -1).to(self.device)
        input_positions = self.input_positions_cpu[
            self.cur_swap_id][:padded_prompt_len].reshape(1, -1).to(
                self.device)
        slot_mapping = self.slot_mapping_cpu[
            self.cur_swap_id][:padded_prompt_len].reshape(1, -1).to(
                self.device)
        block_table = block_table_cpu.reshape(1, -1).to(
            self.device) if block_table_cpu is not None else None

        context_lens = self.prompt_context_lens_cpu[self.cur_swap_id].to(
            self.device)
        effective_query_lens = self.prompt_effective_query_lens_cpu[
            self.cur_swap_id].to(self.device)

        self.swap_step()

        # Attn metadata
        attn_metadata = PallasMetadata(
            num_prefills=1,
            num_prefill_tokens=0,  # NOTE: This is not used.
            num_decode_tokens=0,
            slot_mapping=slot_mapping,
            multi_modal_placeholder_index_maps=None,
            enable_kv_scales_calculation=True,
            block_tables=block_table,
            context_lens=context_lens,
            effective_query_lens=effective_query_lens,
        )

        return PromptData(input_tokens, input_positions, attn_metadata)

    def _prepare_decode(
        self,
        decode_req_ids: List[str],
    ) -> DecodeData:
        # Batch size
        batch_size = len(decode_req_ids)
        padded_batch_size = _get_padded_batch_size(batch_size)
        assert padded_batch_size <= self.max_model_len
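
        # Rows in [batch_size, padded_batch_size) are padding: their
        # positions and context lens are zeroed and their slot mapping is set
        # to _PAD_SLOT_ID below, so their KV-cache writes are dropped
        # (out-of-bound slots are ignored) and their outputs are discarded.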

        # Init [0 .. batch_size - 1]
        req_indices_np = self.arange_np[:padded_batch_size]

        # Input positions
        input_positions_np = self.input_positions_np[
            self.cur_swap_id][:padded_batch_size]
        np.add(self.input_batch.num_computed_tokens_cpu[:padded_batch_size],
               0,
               out=input_positions_np)
        input_positions_np[batch_size:] = 0
        input_positions_cpu = self.input_positions_cpu[
            self.cur_swap_id][:padded_batch_size]

        # Input tokens
        token_indices_np = (
            input_positions_np +
            req_indices_np * self.input_batch.token_ids_cpu.shape[1])
        input_tokens_cpu = self.input_ids_cpu[
            self.cur_swap_id][:padded_batch_size]
        torch.index_select(self.input_batch.token_ids_cpu_tensor.flatten(),
                           0,
                           torch.from_numpy(token_indices_np),
                           out=input_tokens_cpu)
        input_tokens_cpu[batch_size:] = 0

        # Slot mapping
        block_table_indices_np = (
            req_indices_np * self.max_num_blocks_per_req +
            input_positions_np // self.block_size)
        block_table_cpu = self.input_batch.block_table.get_cpu_tensor()
        block_numbers_np = block_table_cpu.flatten(
        )[block_table_indices_np].numpy()
        block_offsets_np = input_positions_np % self.block_size

        slot_mapping_np = self.slot_mapping_np[
            self.cur_swap_id][:padded_batch_size]
        np.add(block_numbers_np * self.block_size,
               block_offsets_np,
               out=slot_mapping_np)
        slot_mapping_np[batch_size:] = _PAD_SLOT_ID

        block_table_cpu = block_table_cpu[:padded_batch_size]

        # Context lens
        context_lens_np = self.decode_context_lens_np[
            self.cur_swap_id][:padded_batch_size]
        np.add(self.input_batch.num_computed_tokens_cpu[:padded_batch_size],
               1,
               out=context_lens_np)
        context_lens_np[batch_size:] = 0

        # Get final tensors
        input_tokens = input_tokens_cpu.reshape(-1, 1).to(self.device)
        input_positions = input_positions_cpu.reshape(-1, 1).to(self.device)
        slot_mapping = self.slot_mapping_cpu[
            self.cur_swap_id][:padded_batch_size].reshape(-1,
                                                          1).to(self.device)
        block_table = block_table_cpu.to(self.device)
        context_lens = self.decode_context_lens_cpu[
            self.cur_swap_id][:padded_batch_size].to(self.device)

        self.swap_step()

        # Attn metadata
        attn_metadata = PallasMetadata(
            num_prefills=0,
            num_prefill_tokens=0,
            num_decode_tokens=padded_batch_size,
            slot_mapping=slot_mapping,
            multi_modal_placeholder_index_maps=None,
            enable_kv_scales_calculation=True,
            block_tables=block_table,
            context_lens=context_lens,
            effective_query_lens=None,
        )

        return DecodeData(input_tokens=input_tokens,
                          input_positions=input_positions,
                          attn_metadata=attn_metadata)

    @torch.no_grad()
    def execute_model(
        self,
        scheduler_output: "SchedulerOutput",
    ) -> ModelRunnerOutput:
        # Update cached state
        self._update_states(scheduler_output)

        # If necessary, swap decodes/prompts so that all decodes are at the
        # start of the batch.
        ensure_decodes_first(self.input_batch)

        # Prepare prompts/decodes info
        pd_info = self._get_prompts_and_decodes(scheduler_output)

        # Init
        num_prompts = len(pd_info.prompt_req_ids)
        num_decodes = len(pd_info.decode_req_ids)
        decode_data = None
        sampled_token_ids = [0] * self.input_batch.num_reqs

        # Run each prompt individually
        is_first = True
        for i in range(num_prompts):
            req_id = pd_info.prompt_req_ids[i]
            req_index = num_decodes + i
            assert req_index == self.input_batch.req_id_to_index[
                req_id]  # TODO: Remove
            req_state = self.requests[req_id]
            num_scheduled_tokens = pd_info.prompt_scheduled_tokens[i]
            prompt_len = num_scheduled_tokens
            seq_len = req_state.num_computed_tokens + num_scheduled_tokens

            # Prepare first prompt
            if is_first:
                prompt_data = self._prepare_prompt(req_index,
                                                   num_scheduled_tokens)
                is_first = False

            # Run forward pass
            with set_forward_context(prompt_data.attn_metadata,
                                     self.vllm_config):
                assert self.model is not None
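                # The compiled ModelWrapperV1 returns one greedily sampled
                # token id per input position; only the id at index
                # prompt_len - 1 (the last real prompt token) is used below.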
                selected_token_ids = self.model(prompt_data.input_tokens,
                                                prompt_data.input_positions,
                                                prompt_data.attn_metadata,
                                                self.kv_caches)

            # In parallel to TPU execution, prepare the next iteration
            if i < num_prompts - 1:
                # There is a next prompt => prepare it
                prompt_data = self._prepare_prompt(
                    req_index + 1, pd_info.prompt_scheduled_tokens[i + 1])
            elif i == num_prompts - 1 and num_decodes > 0:
                # There is a next decode => prepare it
                decode_data = self._prepare_decode(pd_info.decode_req_ids)

            # Update cached state (if prompt is fully done)
            if seq_len >= len(req_state.prompt_token_ids):
                # Transfer sampled tokens from TPU to CPU
                selected_token_ids_cpu = selected_token_ids.cpu()

                # Get output token
                token_id = selected_token_ids_cpu[prompt_len - 1].item()
                sampled_token_ids[req_index] = token_id

                # Add output token to the request
                self.input_batch.token_ids_cpu[req_index, seq_len] = token_id
                self.input_batch.num_tokens[req_index] += 1
                req_state.output_token_ids.append(token_id)

        # Run decodes (a single batch)
        if num_decodes > 0:
            # Prepare decode (if it was not yet prepared)
            if decode_data is None:
                decode_data = self._prepare_decode(pd_info.decode_req_ids)

            # Run forward pass
            with set_forward_context(decode_data.attn_metadata,
                                     self.vllm_config):
                assert self.model is not None
                selected_token_ids = self.model(decode_data.input_tokens,
                                                decode_data.input_positions,
                                                decode_data.attn_metadata,
                                                self.kv_caches)

            # Transfer sampled tokens from TPU to CPU
            decode_token_ids_cpu = selected_token_ids.cpu()
            # Convert to list
            decode_token_ids_list = decode_token_ids_cpu.tolist()

            # Update cached state for each decode request
            for i in range(num_decodes):
                req_id = pd_info.decode_req_ids[i]
                req_index = i
                assert req_index == self.input_batch.req_id_to_index[
                    req_id]  # TODO: Remove
                req_state = self.requests[req_id]
                seq_len = req_state.num_computed_tokens + 1

                token_id = decode_token_ids_list[i]
                sampled_token_ids[req_index] = token_id

                self.input_batch.token_ids_cpu[req_index, seq_len] = token_id
                self.input_batch.num_tokens[req_index] += 1
                req_state.output_token_ids.append(token_id)

        # Create output.
        all_req_ids = pd_info.decode_req_ids + pd_info.prompt_req_ids
        prompt_logprobs_dict: Dict[str, Optional[LogprobsTensors]] = {}
        for req_id in all_req_ids:
            prompt_logprobs_dict[req_id] = None

        model_runner_output = ModelRunnerOutput(
            req_ids=all_req_ids,
            req_id_to_index=self.input_batch.req_id_to_index,
            sampled_token_ids=[[token_id] for token_id in sampled_token_ids],
            spec_token_ids=None,
            logprobs=None,
            prompt_logprobs_dict=prompt_logprobs_dict,  # type: ignore[arg-type]
        )

        return model_runner_output

    def load_model(self) -> None:
        self.device = self.device_config.device

        # NOTE(woosuk): While the executor assigns the TP ranks to the worker
        # process, the ranks can be different from the ranks internally
        # assigned by the xm runtime. Therefore, there is a mismatch in the
        # rank assignment between the gloo (cpu) runtime and the xm (tpu)
        # runtime. This is not a problem in linear layers because all-reduce
        # is rank-agnostic. However, it matters for all-gather as the ranks
        # determine the order of concatenating the output tensors.
        # As a workaround, we use the xm's rank assignment only when loading
        # the embedding weights.
        xm_tp_rank = xr.global_ordinal()
        with patch(
                "vllm.model_executor.layers.vocab_parallel_embedding."
"get_tensor_model_parallel_rank", return_value=xm_tp_rank): model = get_model(vllm_config=self.vllm_config) model = model.eval() xm.mark_step() xm.wait_device_ops() model = ModelWrapperV1(model) self.model = torch.compile(model, backend="openxla", fullgraph=True, dynamic=False) def dummy_run( self, kv_caches, num_tokens: int, seq_len: Optional[int] = None, exec_mode: Optional[ExecutionMode] = None, ) -> None: assert seq_len is not None assert exec_mode is not None exec_mode = ExecutionMode(exec_mode) if exec_mode.is_prefill(): seq_len = (seq_len + 15) // 16 * 16 token_ids = torch.zeros((num_tokens, seq_len), dtype=torch.int32, device=self.device) position_ids = torch.zeros((num_tokens, seq_len), dtype=torch.int32, device=self.device) slot_mapping = torch.zeros((num_tokens, seq_len), dtype=torch.int64, device=self.device) if exec_mode == ExecutionMode.PREFILL: attn_metadata = PallasMetadata( num_prefills=num_tokens, num_prefill_tokens=num_tokens * seq_len, num_decode_tokens=0, slot_mapping=slot_mapping, multi_modal_placeholder_index_maps=None, enable_kv_scales_calculation=True, block_tables=None, context_lens=None, effective_query_lens=None, ) else: context_lens = torch.ones((num_tokens, ), dtype=torch.int32, device=self.device) block_tables = torch.zeros( (num_tokens, self.max_num_blocks_per_req), dtype=torch.int32, device=self.device) effective_query_lens = torch.ones_like(context_lens) attn_metadata = PallasMetadata( num_prefills=num_tokens, num_prefill_tokens=num_tokens * seq_len, num_decode_tokens=0, slot_mapping=slot_mapping, multi_modal_placeholder_index_maps=None, enable_kv_scales_calculation=True, block_tables=block_tables, context_lens=context_lens, effective_query_lens=effective_query_lens, ) else: assert seq_len == 1 token_ids = torch.zeros((num_tokens, seq_len), dtype=torch.int32, device=self.device) position_ids = torch.zeros((num_tokens, seq_len), dtype=torch.int32, device=self.device) slot_mapping = torch.zeros((num_tokens, seq_len), dtype=torch.int64, device=self.device) block_tables = torch.zeros( (num_tokens, self.max_num_blocks_per_req), dtype=torch.int32, device=self.device) context_lens = torch.ones((num_tokens, ), dtype=torch.int32, device=self.device) attn_metadata = PallasMetadata( num_prefills=0, num_prefill_tokens=0, num_decode_tokens=num_tokens * seq_len, slot_mapping=slot_mapping, multi_modal_placeholder_index_maps=None, enable_kv_scales_calculation=True, block_tables=block_tables, context_lens=context_lens, ) # NOTE(woosuk): There are two stages of compilation: torch.compile and # XLA compilation. Using `mark_dynamic` can reduce the torch.compile # overhead by reusing the FX graph for different shapes. # However, the XLA graph will still require static shapes and needs to # be re-compiled for every different shapes. This overhead is inevitable # in the first run, but can be skipped afterwards as we cache the XLA # graphs in the disk (VLLM_XLA_CACHE_PATH). 
        if exec_mode.is_prefill():
            # Prefill
            torch._dynamo.mark_dynamic(token_ids, 1)
            torch._dynamo.mark_dynamic(position_ids, 1)
            torch._dynamo.mark_dynamic(attn_metadata.slot_mapping, 1)
        else:
            # Decode
            torch._dynamo.mark_dynamic(token_ids, 0)
            torch._dynamo.mark_dynamic(position_ids, 0)
            torch._dynamo.mark_dynamic(attn_metadata.slot_mapping, 0)
            torch._dynamo.mark_dynamic(attn_metadata.context_lens, 0)
            torch._dynamo.mark_dynamic(attn_metadata.block_tables, 0)

        with set_forward_context(attn_metadata, self.vllm_config, 0):
            assert self.model is not None
            self.model(token_ids, position_ids, attn_metadata, kv_caches)

    def capture_model(self) -> None:
        """Compile the model."""

        # Prefill
        logger.info(
            "Compiling the model with different input shapes for prefill:")
        start = time.time()
        for batch_size in [1]:
            seq_len = 16
            while seq_len <= self.model_config.max_model_len:
                self.dummy_run(self.kv_caches,
                               batch_size,
                               seq_len,
                               exec_mode=ExecutionMode.PREFILL)
                xm.wait_device_ops()
                logger.info("  batch_size: %d, seq_len: %d", batch_size,
                            seq_len)
                num_tokens = batch_size * seq_len
                if num_tokens >= self.scheduler_config.max_num_batched_tokens:
                    break
                seq_len = seq_len * 2
        end = time.time()
        logger.info("  -- Compilation for prefill done in %.2f [secs].",
                    end - start)

        # Prefix prefill
        if self.scheduler_config.enable_chunked_prefill:
            logger.info("Compiling the model with different input shapes for "
                        "prefix prefill:")
            start = time.time()
            for batch_size in [1]:
                seq_len = 16
                while seq_len <= self.model_config.max_model_len:
                    self.dummy_run(self.kv_caches,
                                   batch_size,
                                   seq_len,
                                   exec_mode=ExecutionMode.PREFIX_PREFILL)
                    xm.wait_device_ops()
                    logger.info("  batch_size: %d, seq_len: %d", batch_size,
                                seq_len)
                    num_tokens = batch_size * seq_len
                    if (num_tokens >=
                            self.scheduler_config.max_num_batched_tokens):
                        break
                    seq_len = seq_len * 2
            end = time.time()
            logger.info(
                "  -- Compilation for prefix prefill done in %.2f [secs].",
                end - start)

        # Decode
        logger.info(
            "Compiling the model with different input shapes for decode:")
        start = time.time()
        seq_len = 1
        batch_size = 8  # Must be in sync with _get_padded_batch_size()
        while True:
            self.dummy_run(self.kv_caches,
                           batch_size,
                           seq_len,
                           exec_mode=ExecutionMode.DECODE)
            xm.wait_device_ops()
            logger.info("  batch_size: %d, seq_len: %d", batch_size, seq_len)

            if batch_size >= self.scheduler_config.max_num_seqs:
                break
            batch_size = batch_size + 16 if batch_size >= 16 else batch_size * 2
        end = time.time()
        logger.info("  -- Compilation for decode done in %.2f [secs].",
                    end - start)
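
        # Illustrative shape schedule (assumed example, not executed): with
        # max_model_len=2048 and max_num_batched_tokens=2048, the loops above
        # compile prefill graphs for seq_len 16, 32, ..., 2048 at batch_size
        # 1, and decode graphs for batch_size 8, 16, 32, 48, ... (seq_len 1)
        # until max_num_seqs is covered.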

    def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None:
        """
        Initialize KV cache based on `kv_cache_config`.

        Args:
            kv_cache_config: Configuration for the KV cache, including the
            KV cache size of each layer.
        """
        if len(kv_cache_config.groups) > 1:
            raise NotImplementedError(
                "Hybrid models with more than one KV cache type are not "
                "supported yet.")

        kv_caches: Dict[str, torch.Tensor] = {}

        for layer_name, layer_spec in kv_cache_config.kv_cache_spec.items():
            tensor_config = kv_cache_config.tensors[layer_name]
            assert tensor_config.size % layer_spec.page_size_bytes == 0
            num_blocks = tensor_config.size // layer_spec.page_size_bytes
            if isinstance(layer_spec, FullAttentionSpec):
                kv_cache_shape = PallasAttentionBackend.get_kv_cache_shape(
                    num_blocks, layer_spec.block_size,
                    layer_spec.num_kv_heads, layer_spec.head_size)
                dtype = layer_spec.dtype

                tpu_k_cache = torch.zeros(kv_cache_shape,
                                          dtype=dtype,
                                          device=self.device)
                tpu_v_cache = torch.zeros_like(tpu_k_cache)

                kv_caches[layer_name] = (tpu_k_cache, tpu_v_cache)
            else:
                raise NotImplementedError

        bind_kv_cache(
            kv_caches,
            self.vllm_config.compilation_config.static_forward_context,
            self.kv_caches)


class ModelWrapperV1(nn.Module):

    def __init__(self, model: nn.Module):
        super().__init__()
        self.model = model

    def forward(
        self,
        token_ids: torch.Tensor,
        position_ids: torch.Tensor,
        attn_metadata: AttentionMetadata,
        kv_caches: List[Tuple[torch.Tensor, torch.Tensor]],
    ) -> torch.Tensor:
        """Executes the forward pass of the model and samples the next token.

        Args:
            token_ids: The input token IDs of shape [batch_size, seq_len].
            position_ids: The input position IDs of shape
                [batch_size, seq_len].
            attn_metadata: The Pallas attention metadata.
            kv_caches: The key and value caches. They can be empty during
                the memory profiling at initialization.
        """
        # Skip this in memory profiling at initialization.
        if attn_metadata is not None and kv_caches[0][0].numel() > 0:
            # index_copy_(slot_mapping) only works when the inserted dimension
            # is 0. However, the KV cache in the Pallas backend has the shape
            # [num_kv_heads, num_blocks, block_size, head_size]. To make it
            # work, we need to flatten the first three dimensions and modify
            # the slot_mapping accordingly.
            num_kv_heads, num_blocks, block_size, _ = kv_caches[0][0].shape
            slot_mapping = attn_metadata.slot_mapping
            slot_mapping = slot_mapping.flatten()
            head_indices = torch.arange(0,
                                        num_kv_heads,
                                        device=slot_mapping.device,
                                        dtype=slot_mapping.dtype)
            head_indices *= block_size * num_blocks
            slot_mapping = slot_mapping.repeat_interleave(num_kv_heads).view(
                -1, num_kv_heads)
            slot_mapping = slot_mapping + head_indices.view(1, -1)
            slot_mapping = slot_mapping.flatten()
            attn_metadata.slot_mapping = slot_mapping
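            # For example (illustrative numbers only): with num_kv_heads=2,
            # num_blocks=4 and block_size=16, a token assigned slot s is
            # written at flattened index s for head 0 and s + 4 * 16 = s + 64
            # for head 1 of the [num_kv_heads * num_blocks * block_size,
            # head_size] cache view.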

        assert self.model is not None
        hidden_states = self.model(
            token_ids,
            position_ids,
            kv_caches,
            attn_metadata,
        )
        hidden_states = hidden_states.flatten(0, 1)
        logits = self.model.compute_logits(hidden_states, None)

        # Greedy sampling.
        argmax_token_ids = torch.argmax(logits, dim=-1, keepdim=True)
        argmax_token_ids = argmax_token_ids.squeeze(dim=-1)
        return argmax_token_ids


def swap_positions(b: InputBatch, id_1, id_2):
    assert id_1 != id_2
    req_id_1 = b.req_ids[id_1]
    req_id_2 = b.req_ids[id_2]
    assert req_id_1 is not None
    assert req_id_2 is not None
    assert id_1 == b.req_id_to_index[req_id_1]
    assert id_2 == b.req_id_to_index[req_id_2]

    b.req_ids[id_1], b.req_ids[id_2] = b.req_ids[id_2], b.req_ids[id_1]
    b.req_id_to_index[req_id_1], b.req_id_to_index[
        req_id_2] = b.req_id_to_index[req_id_2], b.req_id_to_index[req_id_1]

    ids = [id_1, id_2]
    rev_ids = [id_2, id_1]
    b.num_tokens[ids] = b.num_tokens[rev_ids]
    b.token_ids_cpu[ids] = b.token_ids_cpu[rev_ids]
    b.num_prompt_tokens[ids] = b.num_prompt_tokens[rev_ids]
    b.num_computed_tokens_cpu[ids] = b.num_computed_tokens_cpu[rev_ids]

    b.block_table.swap_row(id_1, id_2)

    b.temperature_cpu[ids] = b.temperature_cpu[rev_ids]
    b.top_p_cpu[ids] = b.top_p_cpu[rev_ids]
    b.top_k_cpu[ids] = b.top_k_cpu[rev_ids]
    b.frequency_penalties_cpu[ids] = b.frequency_penalties_cpu[rev_ids]
    b.presence_penalties_cpu[ids] = b.presence_penalties_cpu[rev_ids]
    b.repetition_penalties_cpu[ids] = b.repetition_penalties_cpu[rev_ids]

    b.min_tokens[id_1], b.min_tokens[id_2] = b.min_tokens[id_2], b.min_tokens[
        id_1]
    b.stop_token_ids[id_1], b.stop_token_ids[id_2] = b.stop_token_ids[
        id_2], b.stop_token_ids[id_1]

    gen_1 = b.generators.pop(id_1, None)
    gen_2 = b.generators.pop(id_2, None)
    if gen_1 is not None:
        b.generators[id_2] = gen_1
    if gen_2 is not None:
        b.generators[id_1] = gen_2


def ensure_decodes_first(b: InputBatch):
    num_reqs = b.num_reqs
    while True:
        # Find the first prompt index
        first_prompt_index = None
        for i in range(num_reqs):
            if b.num_computed_tokens_cpu[i] < b.num_prompt_tokens[i]:
                first_prompt_index = i
                break
        if first_prompt_index is None:
            break

        # Find the last decode index
        last_decode_index = None
        for i in reversed(range(num_reqs)):
            if b.num_computed_tokens_cpu[i] >= b.num_prompt_tokens[i]:
                last_decode_index = i
                break
        if last_decode_index is None:
            break

        # Sanity check
        assert first_prompt_index != last_decode_index

        # Check if done
        if first_prompt_index > last_decode_index:
            break

        # Swap
        swap_positions(b, first_prompt_index, last_decode_index)


def _get_padded_prompt_len(x: int) -> int:
    # NOTE(woosuk): The pallas FlashAttention kernel requires the sequence
    # length to be a multiple of 16. We pad the prompt length to the next
    # power of two (with a minimum of 16), which satisfies this requirement
    # and is also good for performance.
    if x <= 16:
        return 16
    return 1 << (x - 1).bit_length()


def _get_padded_batch_size(batch_size: int) -> int:
    # The GMM Pallas kernel requires num_tokens * topk to be a multiple of 16.
    # To meet this requirement in the simplest way, we set the minimal batch
    # size to 8.
    if batch_size <= 8:
        return 8
    else:
        return ((batch_size + 15) // 16) * 16
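

# Illustrative padding behavior of the helpers above (example values derived
# from the formulas; not executed):
#   _get_padded_prompt_len(1)   -> 16
#   _get_padded_prompt_len(17)  -> 32
#   _get_padded_prompt_len(100) -> 128
#   _get_padded_batch_size(3)   -> 8
#   _get_padded_batch_size(20)  -> 32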