From c3b6559a1019c73f5b2436d2d50cbaa3663e5382 Mon Sep 17 00:00:00 2001 From: iefgnoix Date: Fri, 28 Feb 2025 10:01:36 -0800 Subject: [PATCH] [V1][TPU] Integrate the new ragged paged attention kernel with vLLM v1 on TPU (#13379) Signed-off-by: Xiongfei Wei Signed-off-by: mgoin Co-authored-by: mgoin --- requirements-tpu.txt | 11 +- vllm/v1/attention/backends/pallas.py | 278 ++------ vllm/v1/outputs.py | 2 +- vllm/v1/worker/gpu_model_runner.py | 4 +- vllm/v1/worker/tpu_model_runner.py | 921 ++++++++------------------- vllm/v1/worker/tpu_worker.py | 6 +- 6 files changed, 335 insertions(+), 887 deletions(-) diff --git a/requirements-tpu.txt b/requirements-tpu.txt index 8bfbb2dda194..725b1a2e4a58 100644 --- a/requirements-tpu.txt +++ b/requirements-tpu.txt @@ -17,9 +17,8 @@ ray[default] --find-links https://storage.googleapis.com/libtpu-releases/index.html --find-links https://storage.googleapis.com/jax-releases/jax_nightly_releases.html --find-links https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html -torch @ https://download.pytorch.org/whl/nightly/cpu/torch-2.6.0.dev20241216%2Bcpu-cp39-cp39-linux_x86_64.whl ; python_version == "3.9" -torch @ https://download.pytorch.org/whl/nightly/cpu/torch-2.6.0.dev20241216%2Bcpu-cp310-cp310-linux_x86_64.whl ; python_version == "3.10" -torch @ https://download.pytorch.org/whl/nightly/cpu/torch-2.6.0.dev20241216%2Bcpu-cp311-cp311-linux_x86_64.whl ; python_version == "3.11" -torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.7.0.dev20250124-cp39-cp39-linux_x86_64.whl ; python_version == "3.9" -torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.7.0.dev20250124-cp310-cp310-linux_x86_64.whl ; python_version == "3.10" -torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.7.0.dev20250124-cp311-cp311-linux_x86_64.whl ; python_version == "3.11" + +torch==2.7.0.dev20250226+cpu +torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.7.0.dev20250226+cxx11-cp39-cp39-linux_x86_64.whl ; python_version == "3.9" +torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.7.0.dev20250226+cxx11-cp310-cp310-linux_x86_64.whl ; python_version == "3.10" +torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.7.0.dev20250226+cxx11-cp311-cp311-linux_x86_64.whl ; python_version == "3.11" diff --git a/vllm/v1/attention/backends/pallas.py b/vllm/v1/attention/backends/pallas.py index 37bf33f6e3e9..a9f7b3fd4471 100644 --- a/vllm/v1/attention/backends/pallas.py +++ b/vllm/v1/attention/backends/pallas.py @@ -4,13 +4,16 @@ from dataclasses import dataclass from typing import Any, Dict, List, Optional, Tuple, Type import torch -import torch_xla.experimental.custom_kernel # Required to register custom ops. +# Required to register custom ops. 
+import torch_xla.experimental.custom_kernel # noqa: F401 from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, - AttentionLayer, - AttentionMetadata, AttentionType) + AttentionLayer, AttentionType) from vllm.attention.backends.utils import CommonAttentionState +NUM_QUERIES_PER_BLOCK = 16 +NUM_KV_PAGES_PER_BLOCK = 128 + class PallasAttentionBackend(AttentionBackend): @@ -47,47 +50,23 @@ class PallasAttentionBackend(AttentionBackend): ) -> None: raise RuntimeError("swap_blocks is not used for the TPU backend.") - @torch.compile(backend="openxla") - @staticmethod - def copy_blocks( - kv_caches: List[Tuple[torch.Tensor, torch.Tensor]], - src_to_dists: Tuple[torch.Tensor, torch.Tensor], - ) -> None: - src_indices, dst_indices = src_to_dists - for k_cache, v_cache in kv_caches: - torch.ops.xla.dynamo_set_buffer_donor_(k_cache, True) - k_cache[:, dst_indices] = k_cache[:, src_indices] - torch.ops.xla.dynamo_set_buffer_donor_(v_cache, True) - v_cache[:, dst_indices] = v_cache[:, src_indices] - @dataclass -class PallasMetadata(AttentionMetadata): +class PallasMetadata: + # NOTE(sang): Definition of context_len, query_len, and seq_len. + # |---------- N-1 iteration --------| + # |---------------- N iteration ---------------------| + # |- tokenA -|......................|-- newTokens ---| + # |---------- context_len ----------| + # |-------------------- seq_len ---------------------| + # |-- query_len ---| - # Currently, input sequences can only contain all prefills - # or all decoding. - block_tables: Optional[torch.Tensor] = None - context_lens: Optional[torch.Tensor] = None - effective_query_lens: Optional[torch.Tensor] = None - - @property - def prefill_metadata(self) -> Optional["PallasMetadata"]: - if self.num_prefills == 0: - return None - - assert self.num_decode_tokens == 0 - return self - - @property - def decode_metadata(self) -> Optional["PallasMetadata"]: - if self.num_decode_tokens == 0: - return None - - assert self.num_prefills == 0 - assert self.num_prefill_tokens == 0 - assert self.block_tables is not None - assert self.context_lens is not None - return self + # Used in the PallasAttentionBackendImpl + slot_mapping: torch.Tensor + block_tables: torch.Tensor + context_lens: torch.Tensor + query_start_loc: torch.Tensor + num_seqs: int class PallasAttentionBackendImpl(AttentionImpl): @@ -105,10 +84,13 @@ class PallasAttentionBackendImpl(AttentionImpl): logits_soft_cap: Optional[float] = None, attn_type: str = AttentionType.DECODER, ) -> None: + if blocksparse_params is not None: + raise ValueError("Paged attention Pallas kernel does " + "not support block-sparse attention.") self.num_heads = num_heads self.head_size = head_size self.scale = float(scale) - self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads + self.num_kv_heads = num_kv_heads assert self.num_heads % self.num_kv_heads == 0 self.num_queries_per_kv = self.num_heads // self.num_kv_heads @@ -126,25 +108,6 @@ class PallasAttentionBackendImpl(AttentionImpl): raise NotImplementedError( "Attention logits soft-capping is not supported.") - if torch_xla.tpu.version() < 4: - raise NotImplementedError("TPU version must be 4 or higher.") - - self.megacore_mode = None - tpu_env = torch_xla.tpu.get_tpu_env() - tpu_type = (tpu_env.get("ACCELERATOR_TYPE", None) - or tpu_env.get("TYPE", None) - or tpu_env.get("TPU_ACCELERATOR_TYPE", None)) - assert tpu_type is not None - tpu_type = tpu_type.lower() - - if (("lite" not in tpu_type) and ("v6" not in tpu_type)): - if self.num_kv_heads % 2 == 0: - 
self.megacore_mode = "kv_head" - else: - # NOTE(woosuk): If the batch size is not a multiple of 2, the - # megacore mode will be None. - self.megacore_mode = "batch" - if attn_type != AttentionType.DECODER: raise NotImplementedError("Encoder self-attention and " "encoder/decoder cross-attention " @@ -164,135 +127,47 @@ class PallasAttentionBackendImpl(AttentionImpl): """Forward pass with Pallas attention. Args: - query: shape = [batch_size, seq_len, num_heads * head_size] - key: shape = [batch_size, seq_len, num_kv_heads * head_size] - value: shape = [batch_size, seq_len, num_kv_heads * head_size] - kv_cache[0] = [num_kv_heads, num_blocks, block_size, head_size] - kv_cache[1] = [num_kv_heads, num_blocks, block_size, head_size] - NOTE: kv_cache[0] and kv_cache[1] will be an empty tensor - with shape [0] for profiling run. + query: shape = [num_tokens, num_heads * head_size] + key: shape = [num_tokens, num_kv_heads * head_size] + value: shape = [num_tokens, num_kv_heads * head_size] + kv_cache = ([num_kv_heads, num_blocks, block_size, head_size], + [num_kv_heads, num_blocks, block_size, head_size]) attn_metadata: Metadata for attention. Returns: - shape = [batch_size, seq_len, num_heads * head_size] + shape = [num_tokens, num_heads * head_size] """ - - if attn_metadata is None: + # For determine_available_memory case. + if kv_cache[0].numel() == 0: if output is None: output = torch.ones_like(query) return output assert layer._k_scale_float == 1.0 and layer._v_scale_float == 1.0 - batch_size, seq_len, hidden_size = query.shape - query = query.view(batch_size, seq_len, self.num_heads, self.head_size) - key = key.view(batch_size, seq_len, self.num_kv_heads, self.head_size) - value = value.view(batch_size, seq_len, self.num_kv_heads, - self.head_size) + num_tokens, hidden_size = query.shape + query = query.view(num_tokens, self.num_heads, self.head_size) + key = key.view(num_tokens, self.num_kv_heads, self.head_size) + value = value.view(num_tokens, self.num_kv_heads, self.head_size) + key_cache, value_cache = kv_cache if kv_cache[0].numel() > 0: slot_mapping = attn_metadata.slot_mapping - key_cache, value_cache = kv_cache write_to_kv_cache(key, value, key_cache, value_cache, slot_mapping) query = query * self.scale - if attn_metadata.num_prefills > 0: - if attn_metadata.block_tables is None: - # Prefill without paged KV cache. - assert seq_len % 16 == 0, ( - "Pallas FlashAttention kernel requires seq_len to be a " - f"multiple of 16 but got {seq_len}") + output = torch.ops.xla.ragged_paged_attention( + query, + key_cache, + value_cache, + attn_metadata.context_lens, + attn_metadata.block_tables, + attn_metadata.query_start_loc, + attn_metadata.num_seqs, + num_kv_pages_per_block=NUM_KV_PAGES_PER_BLOCK, + num_queries_per_block=NUM_QUERIES_PER_BLOCK, + use_kernel=False, + ) - # Handle GQA/MQA. - if self.num_kv_heads != self.num_heads: - key = key.repeat_interleave(self.num_queries_per_kv, - dim=-2) - key = key.view(batch_size, seq_len, self.num_heads, - self.head_size) - value = value.repeat_interleave(self.num_queries_per_kv, - dim=-2) - value = value.view(batch_size, seq_len, self.num_heads, - self.head_size) - # FlashAttention kernel requires the input shape to be - # [batch_size, num_heads, seq_len, d_model] - # while the input is [batch_size, seq_len, num_heads, d_model]. - # Permute the input to match the required format. 
- output = torch.ops.xla.flash_attention( - query.permute(0, 2, 1, 3), - key.permute(0, 2, 1, 3), - value.permute(0, 2, 1, 3), - True, - ) - output = output.permute(0, 2, 1, 3) - else: - # Prefill with paged KV cache. - # TODO(woosuk): Tune the below knobs. - num_kv_pages_per_compute_block = 16 - num_queries_per_compute_block = 16 - assert seq_len % num_queries_per_compute_block == 0 - output = torch.ops.xla.multi_queries_paged_attention( - query, - key_cache, - value_cache, - attn_metadata.context_lens, - attn_metadata.block_tables, - attn_metadata.effective_query_lens, - num_kv_pages_per_compute_block, - num_queries_per_compute_block, - use_kernel=True, - ) - else: - # Decoding run. - assert kv_cache[0].numel() > 0 - query = query.squeeze(dim=1) - pages_per_compute_block = 16 # TODO(woosuk): Tune this value. - - assert attn_metadata.block_tables is not None - assert attn_metadata.context_lens is not None - # NOTE(woosuk): The PagedAttention Pallas kernel stores the entire - # block table in SMEM. Therefore, if the block table is too large, - # the kernel compilation will fail. To avoid this, we split the - # batch dimension into smaller chunks and run the kernel multiple - # times. - MAX_SMEM_USAGE = 512 * 1024 - size_per_seq = 4 * attn_metadata.block_tables.shape[1] - max_num_seq = MAX_SMEM_USAGE // size_per_seq - - if batch_size <= max_num_seq: - output = paged_attention( - query, - key_cache, - value_cache, - attn_metadata.context_lens, - attn_metadata.block_tables, - pages_per_compute_block, - self.megacore_mode, - ) - else: - chunk_size = max_num_seq - # Make sure the chunk size is a multiple of 2. - chunk_size = chunk_size // 2 * 2 - num_chunks = (batch_size + chunk_size - 1) // chunk_size - - output = torch.empty_like(query) - for chunk_idx in range(num_chunks): - chunk_start = chunk_idx * chunk_size - chunk_end = chunk_start + chunk_size - # NOTE(woosuk): We skip this line because it causes Dynamo - # compilation error. Instead, we rely on the slice operation - # to handle the out-of-bound case. - # chunk_end = min(chunk_end, batch_size) - chunk_output = paged_attention( - query[chunk_start:chunk_end], - key_cache, - value_cache, - attn_metadata.context_lens[chunk_start:chunk_end], - attn_metadata.block_tables[chunk_start:chunk_end], - pages_per_compute_block, - self.megacore_mode, - ) - output[chunk_start:chunk_end] = chunk_output - - # Reshape the output tensor. - return output.reshape(batch_size, seq_len, hidden_size) + return output.reshape(num_tokens, hidden_size) def write_to_kv_cache( @@ -302,52 +177,21 @@ def write_to_kv_cache( value_cache: torch.Tensor, slot_mapping: torch.Tensor, ) -> None: + """ Write the key and values to the KV cache. 
+
+    Args:
+        key: shape = [num_tokens, num_kv_heads, head_size]
+        value: shape = [num_tokens, num_kv_heads, head_size]
+        key_cache: shape = [num_kv_heads, num_blocks, block_size, head_size]
+        value_cache: shape = [num_kv_heads, num_blocks, block_size, head_size]
+    """
     torch.ops.xla.dynamo_set_buffer_donor_(key_cache, True)
     torch.ops.xla.dynamo_set_buffer_donor_(value_cache, True)
 
-    key = key.flatten(0, 2)
-    value = value.flatten(0, 2)
+    key = key.flatten(0, 1)
+    value = value.flatten(0, 1)
     key_cache = key_cache.flatten(0, 2)
     value_cache = value_cache.flatten(0, 2)
     key_cache.index_copy_(0, slot_mapping, key)
     value_cache.index_copy_(0, slot_mapping, value)
-
-
-def paged_attention(
-    query: torch.Tensor,
-    key_cache: torch.Tensor,
-    value_cache: torch.Tensor,
-    context_lens: torch.Tensor,
-    block_tables: torch.Tensor,
-    pages_per_compute_block: int,
-    megacore_mode: Optional[str],
-) -> torch.Tensor:
-    batch_size = query.shape[0]
-    if megacore_mode == "batch" and batch_size % 2 != 0:
-        megacore_mode = None
-    else:
-        megacore_mode = megacore_mode
-
-    # NOTE(woosuk): A temporary workaround to avoid the error:
-    # "xla::paged_attention() Expected a value of type 'str' for
-    # argument 'megacore_mode' but instead found type 'NoneType'."
-    if megacore_mode is not None:
-        output = torch.ops.xla.paged_attention(
-            query,
-            key_cache,
-            value_cache,
-            context_lens,
-            block_tables,
-            pages_per_compute_block,
-            megacore_mode=megacore_mode,
-        )
-    else:
-        output = torch.ops.xla.paged_attention(
-            query,
-            key_cache,
-            value_cache,
-            context_lens,
-            block_tables,
-            pages_per_compute_block,
-        )
-    return output
diff --git a/vllm/v1/outputs.py b/vllm/v1/outputs.py
index 0c8eca38ade7..f461d52cc984 100644
--- a/vllm/v1/outputs.py
+++ b/vllm/v1/outputs.py
@@ -79,4 +79,4 @@ class ModelRunnerOutput:
     # [prompt_len, num_prompt_logprobs]
     # [prompt_len, num_prompt_logprobs]
     # [prompt_len]
-    prompt_logprobs_dict: Dict[str, LogprobsTensors]
+    prompt_logprobs_dict: Dict[str, Optional[LogprobsTensors]]
diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py
index 2730e6770dc3..e255becbefbf 100644
--- a/vllm/v1/worker/gpu_model_runner.py
+++ b/vllm/v1/worker/gpu_model_runner.py
@@ -1071,12 +1071,12 @@ class GPUModelRunner(LoRAModelRunnerMixin):
         self,
         hidden_states: torch.Tensor,
         scheduler_output: "SchedulerOutput",
-    ) -> Dict[str, LogprobsTensors]:
+    ) -> Dict[str, Optional[LogprobsTensors]]:
         num_prompt_logprobs_dict = self.input_batch.num_prompt_logprobs
         if not num_prompt_logprobs_dict:
             return {}
 
-        prompt_logprobs_dict: Dict[str, LogprobsTensors] = {}
+        prompt_logprobs_dict: Dict[str, Optional[LogprobsTensors]] = {}
 
         # Since prompt logprobs are a rare feature, prioritize simple,
         # maintainable loop over optimal performance.
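Before moving on to the model-runner changes, here is a standalone sanity sketch (not part of the patch; toy shapes, and a single KV head so the per-token slot mapping can be used directly) of the scatter that the rewritten `write_to_kv_cache` performs:

```python
import torch

# Toy shapes: 1 KV head, 4 blocks of 2 slots each, head_size 8.
num_kv_heads, num_blocks, block_size, head_size = 1, 4, 2, 8
num_tokens = 3

key = torch.randn(num_tokens, num_kv_heads, head_size)
key_cache = torch.zeros(num_kv_heads, num_blocks, block_size, head_size)

# slot = block_number * block_size + block_offset; say the three tokens
# land in block 2 (slots 4 and 5) and block 0 (slot 0).
slot_mapping = torch.tensor([4, 5, 0])

# Same steps as write_to_kv_cache: flatten the tokens and the cache,
# then scatter rows by slot index.
flat_key = key.flatten(0, 1)          # [num_tokens * num_kv_heads, head_size]
flat_cache = key_cache.flatten(0, 2)  # [heads * blocks * block_size, head_size]
flat_cache.index_copy_(0, slot_mapping, flat_key)

# flatten() on the contiguous cache returns a view, so key_cache itself
# is updated in place.
assert torch.equal(key_cache[0, 2, 0], key[0, 0])
assert torch.equal(key_cache[0, 0, 0], key[2, 0])
```

With more than one KV head the slot mapping has to be rebuilt per head, which is what `ModelWrapperV1.forward` does further down in this patch.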
diff --git a/vllm/v1/worker/tpu_model_runner.py b/vllm/v1/worker/tpu_model_runner.py index f7d72d26e045..d16a0a4165c7 100644 --- a/vllm/v1/worker/tpu_model_runner.py +++ b/vllm/v1/worker/tpu_model_runner.py @@ -1,8 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 -import enum import time -from dataclasses import dataclass -from typing import TYPE_CHECKING, Dict, List, Optional, Tuple +from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, cast from unittest.mock import patch import numpy as np @@ -21,7 +19,9 @@ from vllm.logger import init_logger from vllm.model_executor.model_loader import get_model from vllm.sampling_params import SamplingType from vllm.utils import LayerBlockType, cdiv, is_pin_memory_available -from vllm.v1.attention.backends.pallas import (PallasAttentionBackend, +from vllm.v1.attention.backends.pallas import (NUM_KV_PAGES_PER_BLOCK, + NUM_QUERIES_PER_BLOCK, + PallasAttentionBackend, PallasMetadata) from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig, KVCacheSpec) @@ -37,36 +37,7 @@ logger = init_logger(__name__) # Here we utilize the behavior that out-of-bound index is ignored. # FIXME(woosuk): Find a more reliable way to prevent possible bugs. _PAD_SLOT_ID = 1_000_000_000 - - -class ExecutionMode(enum.Enum): - PREFILL = enum.auto() - DECODE = enum.auto() - PREFIX_PREFILL = enum.auto() - - def is_prefill(self) -> bool: - return self in (ExecutionMode.PREFILL, ExecutionMode.PREFIX_PREFILL) - - -@dataclass -class PromptDecodeInfo: - prompt_req_ids: List[str] - decode_req_ids: List[str] - prompt_scheduled_tokens: List[int] - - -@dataclass -class PromptData: - input_tokens: torch.Tensor - input_positions: torch.Tensor - attn_metadata: PallasMetadata - - -@dataclass -class DecodeData: - input_tokens: Optional[torch.Tensor] = None - input_positions: Optional[torch.Tensor] = None - attn_metadata: Optional[PallasMetadata] = None +INVALID_TOKEN_ID = -1 class TPUModelRunner: @@ -113,8 +84,6 @@ class TPUModelRunner: self.head_size = model_config.get_head_size() self.hidden_size = model_config.get_hidden_size() - self.model: Optional[nn.Module] = None - # Persistent batch. self.input_batch = InputBatch( max_num_reqs=self.max_num_reqs, @@ -134,50 +103,48 @@ class TPUModelRunner: # KV caches for forward pass self.kv_caches: List[Tuple[torch.Tensor, torch.Tensor]] = [] - # Cached torch/numpy tensors - self.num_swaps = 2 - self.cur_swap_id = 0 - self.input_ids_cpu = [] - self.input_ids_np = [] - self.input_positions_cpu = [] - self.input_positions_np = [] - self.slot_mapping_cpu = [] - self.slot_mapping_np = [] - self.prompt_context_lens_cpu = [] - self.prompt_effective_query_lens_cpu = [] - self.decode_context_lens_cpu = [] - self.decode_context_lens_np = [] - for _ in range(self.num_swaps): - self.input_ids_cpu.append( - torch.empty(self.max_num_tokens, - dtype=torch.int32, - device="cpu")) - self.input_ids_np.append(self.input_ids_cpu[-1].numpy()) + # Cached torch/numpy tensor + # The pytorch tensor and numpy array share the same buffer. + # Sometimes the numpy op is faster so we create both. 
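A standalone illustration of the shared-buffer comment above (not part of the patch): for CPU tensors, `Tensor.numpy()` returns a zero-copy view, so whichever of the torch/numpy handles is faster for a given op can be used and the other sees the result:

```python
import numpy as np
import torch

buf_cpu = torch.zeros(8, dtype=torch.int32, device="cpu")
buf_np = buf_cpu.numpy()  # zero-copy view of the same memory

# Write through the numpy view...
np.add(np.arange(8, dtype=np.int32), 100, out=buf_np)

# ...and the torch tensor observes the update immediately.
assert buf_cpu[7].item() == 107
```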
+        self.input_ids_cpu = torch.zeros(self.max_num_tokens,
+                                         dtype=torch.int32,
+                                         device="cpu")
+        self.input_ids_np = self.input_ids_cpu.numpy()
 
-            self.input_positions_cpu.append(
-                torch.empty(self.max_num_tokens,
-                            dtype=torch.int32,
-                            device="cpu"))
-            self.input_positions_np.append(
-                self.input_positions_cpu[-1].numpy())
+        self.positions_cpu = torch.zeros(self.max_num_tokens,
+                                         dtype=torch.int32,
+                                         device="cpu")
+        self.positions_np = self.positions_cpu.numpy()
 
-            self.slot_mapping_cpu.append(
-                torch.empty(self.max_num_tokens,
-                            dtype=torch.int64,
-                            device="cpu"))
-            self.slot_mapping_np.append(self.slot_mapping_cpu[-1].numpy())
+        self.slot_mapping_cpu = torch.zeros(self.max_num_tokens,
+                                            dtype=torch.int64,
+                                            device="cpu")
+        self.slot_mapping_np = self.slot_mapping_cpu.numpy()
 
-            self.prompt_context_lens_cpu.append(
-                torch.empty((1), dtype=torch.int32, device="cpu"))
-            self.prompt_effective_query_lens_cpu.append(
-                torch.empty((1), dtype=torch.int32, device="cpu"))
+        # self.input_batch.block_table has shape [max_num_reqs,
+        # max_num_blocks_per_req]. To reduce the number of recompilations,
+        # we want block_table.shape[0] to be num_tokens.
+        # To make the block_table compatible with the paged attention
+        # kernel, block_table.shape[1] must be a multiple of
+        # NUM_KV_PAGES_PER_BLOCK.
+        padded_max_num_blocks_per_req = _get_padded_number(
+            self.max_num_blocks_per_req, NUM_KV_PAGES_PER_BLOCK)
+        self.block_table_cpu = torch.zeros(
+            (self.max_num_tokens, padded_max_num_blocks_per_req),
+            dtype=self.input_batch.block_table.get_cpu_tensor().dtype,
+            device="cpu")
 
-            self.decode_context_lens_cpu.append(
-                torch.empty(self.max_num_tokens,
-                            dtype=torch.int32,
-                            device="cpu"))
-            self.decode_context_lens_np.append(
-                self.decode_context_lens_cpu[-1].numpy())
+        self.query_start_loc_cpu = torch.zeros(self.max_num_tokens + 1,
+                                               dtype=torch.int32,
+                                               device="cpu",
+                                               pin_memory=self.pin_memory)
+        self.query_start_loc_np = self.query_start_loc_cpu.numpy()
+
+        self.seq_lens_cpu = torch.zeros(self.max_num_tokens,
+                                        dtype=torch.int32,
+                                        device="cpu",
+                                        pin_memory=self.pin_memory)
+        self.seq_lens_np = self.seq_lens_cpu.numpy()
 
         # Range tensor with values [0 .. self.max_num_tokens - 1].
         # Used to initialize positions / context_lens / seq_lens
@@ -191,7 +158,7 @@
         the input GPU tensors for the model.
 
         Returns:
-            True if there is a new/resumed/paused/finished request in the batch.
+            True if there is a new/resumed/paused/finished request.
             If False, we can skip copying SamplingMetadata to the GPU.
         """
         # Remove finished requests from the cached states.
@@ -303,9 +270,6 @@
         self.input_batch.condense(removed_req_indices)
         return len(unscheduled_req_ids) > 0 or len(req_ids_to_add) > 0
 
-    def swap_step(self):
-        self.cur_swap_id = (self.cur_swap_id + 1) % self.num_swaps
-
     def get_model(self) -> nn.Module:
         assert self.model is not None
         return self.model
@@ -345,238 +309,124 @@
 
         return kv_cache_spec
 
-    def _get_prompts_and_decodes(
-        self,
-        scheduler_output: "SchedulerOutput",
-    ) -> PromptDecodeInfo:
+    def _prepare_inputs(self, scheduler_output: "SchedulerOutput"):
         total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
         assert total_num_scheduled_tokens > 0
         num_reqs = self.input_batch.num_reqs
         assert num_reqs > 0
 
-        # Traverse decodes first
-        decode_req_ids = []
-        for i in range(num_reqs):
-            req_id = self.input_batch.req_ids[i]
+        # Get the number of scheduled tokens for each request.
+        num_scheduled_tokens_per_req = []
+        max_num_scheduled_tokens_all_reqs = 0
+        for req_id in self.input_batch.req_ids[:num_reqs]:
             assert req_id is not None
+            num_tokens = scheduler_output.num_scheduled_tokens[req_id]
+            num_scheduled_tokens_per_req.append(num_tokens)
+            max_num_scheduled_tokens_all_reqs = max(
+                max_num_scheduled_tokens_all_reqs, num_tokens)
+        num_scheduled_tokens_per_req = np.array(num_scheduled_tokens_per_req,
+                                                dtype=np.int32)
+        assert max_num_scheduled_tokens_all_reqs > 0
 
-            num_computed_tokens = self.input_batch.num_computed_tokens_cpu[i]
-            num_prompt_tokens = self.input_batch.num_prompt_tokens[i]
-            num_scheduled_tokens = scheduler_output.num_scheduled_tokens[
-                req_id]
+        # Get request indices.
+        # E.g., [2, 5, 3] -> [0, 0, 1, 1, 1, 1, 1, 2, 2, 2]
+        # For each scheduled token, the index of the request it belongs to.
+        req_indices = np.repeat(self.arange_np[:num_reqs],
+                                num_scheduled_tokens_per_req)
 
-            if num_computed_tokens < num_prompt_tokens:
-                # This is prompt
-                break
+        # Get batched arange.
+        # E.g., [2, 5, 3] -> [0, 1, 0, 1, 2, 3, 4, 0, 1, 2]
+        # For each scheduled token, its position within the corresponding
+        # request.
+        arange = np.concatenate(
+            [self.arange_np[:n] for n in num_scheduled_tokens_per_req])
 
-            # This is decode
-            assert num_scheduled_tokens == 1
-            decode_req_ids.append(req_id)
+        # Get positions.
+        positions_np = self.positions_np[:total_num_scheduled_tokens]
+        np.add(self.input_batch.num_computed_tokens_cpu[req_indices],
+               arange,
+               out=positions_np)
 
-        # Traverse prompts
-        prompt_req_ids = []
-        prompt_scheduled_tokens = []
-        for i in range(len(decode_req_ids), num_reqs):
-            req_id = self.input_batch.req_ids[i]
-            assert req_id is not None
+        # Get token indices.
+        # E.g., [0, 1, 0, 1, 2, 3, 4, 0, 1, 2]
+        # -> [0, 1, M, M + 1, M + 2, M + 3, M + 4, 2 * M, 2 * M + 1, 2 * M + 2]
+        # where M is the max_model_len.
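A standalone numpy sketch of the index construction above, plus the `token_indices` line that follows, using the `[2, 5, 3]` example from the comments (`num_computed_tokens` and `M` are made-up values):

```python
import numpy as np

num_scheduled_tokens_per_req = np.array([2, 5, 3], dtype=np.int32)
num_reqs = 3
arange_np = np.arange(16, dtype=np.int32)  # stand-in for self.arange_np
M = 1024                                   # stand-in for max_model_len

# [2, 5, 3] -> [0, 0, 1, 1, 1, 1, 1, 2, 2, 2]
req_indices = np.repeat(arange_np[:num_reqs], num_scheduled_tokens_per_req)

# [2, 5, 3] -> [0, 1, 0, 1, 2, 3, 4, 0, 1, 2]
arange = np.concatenate(
    [arange_np[:n] for n in num_scheduled_tokens_per_req])

# Each request's slice is shifted by its number of computed tokens.
num_computed_tokens = np.array([0, 0, 0], dtype=np.int32)
positions = num_computed_tokens[req_indices] + arange

# Row-major index into the [num_reqs, max_model_len] token table.
token_indices = positions + req_indices * M
print(token_indices)
# [   0    1 1024 1025 1026 1027 1028 2048 2049 2050]
```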
+ token_indices = (positions_np + + req_indices * self.input_batch.token_ids_cpu.shape[1]) - num_computed_tokens = self.input_batch.num_computed_tokens_cpu[i] - num_prompt_tokens = self.input_batch.num_prompt_tokens[i] - num_scheduled_tokens = scheduler_output.num_scheduled_tokens[ - req_id] - - # Must be prompt - assert num_computed_tokens < num_prompt_tokens - - prompt_req_ids.append(req_id) - prompt_scheduled_tokens.append(num_scheduled_tokens) - - return PromptDecodeInfo(prompt_req_ids, decode_req_ids, - prompt_scheduled_tokens) - - def _prepare_prompt(self, req_index: int, - num_scheduled_tokens: int) -> PromptData: - num_computed_tokens = self.input_batch.num_computed_tokens_cpu[ - req_index] - num_prompt_tokens = self.input_batch.num_prompt_tokens[req_index] - - # Must be prompt - assert num_computed_tokens < num_prompt_tokens - - # Prompt len - prompt_len = num_scheduled_tokens - padded_prompt_len = _get_padded_prompt_len(prompt_len) - assert padded_prompt_len <= self.max_model_len - - # Seq len - seq_len = num_computed_tokens + prompt_len - padded_seq_len = num_computed_tokens + padded_prompt_len - - # Input tokens - input_tokens_cpu = self.input_batch.token_ids_cpu_tensor[ - req_index, num_computed_tokens:padded_seq_len] - input_tokens_cpu[prompt_len:] = 0 - - # Input positions - input_positions_np = self.input_positions_np[ - self.cur_swap_id][:padded_prompt_len] - np.add(num_computed_tokens, - self.arange_np[:padded_prompt_len], - out=input_positions_np) - input_positions_np[prompt_len:] = 0 - - # Slot mapping - block_table_np = \ - self.input_batch.block_table.get_numpy_array() - block_numbers_np = block_table_np[req_index, input_positions_np // - self.block_size] - block_offsets_np = input_positions_np % self.block_size - - slot_mapping_np = self.slot_mapping_np[ - self.cur_swap_id][:padded_prompt_len] - np.add(block_numbers_np * self.block_size, - block_offsets_np, - out=slot_mapping_np) - slot_mapping_np[prompt_len:] = _PAD_SLOT_ID - - # Block table - block_table_cpu = None - if num_computed_tokens > 0: - block_table_cpu = self.input_batch.block_table.get_cpu_tensor() - block_table_cpu = block_table_cpu[req_index] - - # Context len - self.prompt_context_lens_cpu[self.cur_swap_id][0] = 0 - if num_computed_tokens > 0: - self.prompt_context_lens_cpu[self.cur_swap_id][0] = seq_len - - # Effective query len - self.prompt_effective_query_lens_cpu[self.cur_swap_id][0] = prompt_len - - # Get final tensors - input_tokens = input_tokens_cpu.reshape(1, -1).to(self.device) - input_positions = self.input_positions_cpu[ - self.cur_swap_id][:padded_prompt_len].reshape(1, - -1).to(self.device) - slot_mapping = self.slot_mapping_cpu[ - self.cur_swap_id][:padded_prompt_len].reshape(1, - -1).to(self.device) - block_table = block_table_cpu.reshape(1, -1).to( - self.device) if block_table_cpu is not None else None - - context_lens = self.prompt_context_lens_cpu[self.cur_swap_id].to( - self.device) - effective_query_lens = self.prompt_effective_query_lens_cpu[ - self.cur_swap_id].to(self.device) - - self.swap_step() - - # Attn metadata - attn_metadata = PallasMetadata( - num_prefills=1, - num_prefill_tokens=0, # NOTE: This is not used. 
- num_decode_tokens=0, - slot_mapping=slot_mapping, - multi_modal_placeholder_index_maps=None, - enable_kv_scales_calculation=True, - block_tables=block_table, - context_lens=context_lens, - effective_query_lens=effective_query_lens, - ) - - return PromptData(input_tokens, input_positions, attn_metadata) - - def _prepare_decode( - self, - decode_req_ids: List[str], - ) -> DecodeData: - # Batch size - batch_size = len(decode_req_ids) - padded_batch_size = _get_padded_batch_size(batch_size) - assert padded_batch_size <= self.max_model_len - - # Init [0 .. batch_size - 1] - req_indices_np = self.arange_np[:padded_batch_size] - - # Input positions - input_positions_np = self.input_positions_np[ - self.cur_swap_id][:padded_batch_size] - np.add(self.input_batch.num_computed_tokens_cpu[:padded_batch_size], - 0, - out=input_positions_np) - input_positions_np[batch_size:] = 0 - input_positions_cpu = self.input_positions_cpu[ - self.cur_swap_id][:padded_batch_size] - - # Input tokens - token_indices_np = ( - input_positions_np + - req_indices_np * self.input_batch.token_ids_cpu.shape[1]) - input_tokens_cpu = self.input_ids_cpu[ - self.cur_swap_id][:padded_batch_size] + # NOTE(woosuk): We use torch.index_select instead of np.take here + # because torch.index_select is much faster than np.take for large + # tensors. torch.index_select(self.input_batch.token_ids_cpu_tensor.flatten(), 0, - torch.from_numpy(token_indices_np), - out=input_tokens_cpu) - input_tokens_cpu[batch_size:] = 0 - - # Slot mapping - block_table_indices_np = ( - req_indices_np * self.max_num_blocks_per_req + - input_positions_np // self.block_size) + torch.from_numpy(token_indices), + out=self.input_ids_cpu[:total_num_scheduled_tokens]) + # Calculate the slot mapping. + # E.g., [0, 1, 0, 1, 2, 3, 4, 0, 1, 2] + # -> [0, 0, K, K, K + 1, K + 1, K + 2, 2 * K, 2 * K, 2 * K + 1] + # where K is the max_num_blocks_per_req and the block size is 2. + # NOTE(woosuk): We can't simply use `token_indices // block_size` here + # because M (max_model_len) is not necessarily divisible by block_size. + # req_indices: # E.g., [2, 5, 3] -> [0, 0, 1, 1, 1, 1, 1, 2, 2, 2] + block_table_indices = (req_indices * self.max_num_blocks_per_req + + positions_np // self.block_size) + # NOTE(woosuk): We use torch.index_select instead of np.take here + # because torch.index_select is much faster than np.take for large + # tensors. block_table_cpu = self.input_batch.block_table.get_cpu_tensor() + block_numbers = block_table_cpu.flatten()[block_table_indices].numpy() + block_offsets = positions_np % self.block_size + np.add(block_numbers * self.block_size, + block_offsets, + out=self.slot_mapping_np[:total_num_scheduled_tokens]) - block_numbers_np = block_table_cpu.flatten( - )[block_table_indices_np].numpy() + # Prepare the attention metadata. + self.query_start_loc_np[0] = 0 + np.cumsum(num_scheduled_tokens_per_req, + out=self.query_start_loc_np[1:num_reqs + 1]) - block_offsets_np = input_positions_np % self.block_size + self.seq_lens_np[:num_reqs] = ( + self.input_batch.num_computed_tokens_cpu[:num_reqs] + + num_scheduled_tokens_per_req) - slot_mapping_np = self.slot_mapping_np[ - self.cur_swap_id][:padded_batch_size] - np.add(block_numbers_np * self.block_size, - block_offsets_np, - out=slot_mapping_np) - slot_mapping_np[batch_size:] = _PAD_SLOT_ID + # Do the padding and copy the tensors to the TPU. 
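For reference before the padding step below: `_get_padded_number` (added at the end of this file) rounds its argument up to the next multiple, so the total scheduled-token count gets padded to a multiple of `NUM_QUERIES_PER_BLOCK`:

```python
def _get_padded_number(n: int, multiple: int) -> int:
    return ((n + multiple - 1) // multiple) * multiple

NUM_QUERIES_PER_BLOCK = 16  # constant introduced in pallas.py above

assert _get_padded_number(1, NUM_QUERIES_PER_BLOCK) == 16
assert _get_padded_number(16, NUM_QUERIES_PER_BLOCK) == 16
assert _get_padded_number(17, NUM_QUERIES_PER_BLOCK) == 32
```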
+ padded_total_num_scheduled_tokens = _get_padded_number( + total_num_scheduled_tokens, NUM_QUERIES_PER_BLOCK) + self.input_ids = self.input_ids_cpu[: + padded_total_num_scheduled_tokens].to( + self.device) + self.position_ids = self.positions_cpu[: + padded_total_num_scheduled_tokens].to( + self.device) + self.slot_mapping_cpu[total_num_scheduled_tokens:] = _PAD_SLOT_ID + slot_mapping = self.slot_mapping_cpu[: + padded_total_num_scheduled_tokens].to( + self.device) + padded_block_table = self.block_table_cpu[: + padded_total_num_scheduled_tokens] + padded_block_table[:num_reqs, :self.max_num_blocks_per_req] = ( + self.input_batch.block_table.get_cpu_tensor()[:num_reqs]) + padded_block_table = padded_block_table.to(self.device) + query_start_loc = self.query_start_loc_cpu[: + padded_total_num_scheduled_tokens + + 1].to(self.device) + seq_lens = self.seq_lens_cpu[:padded_total_num_scheduled_tokens].to( + self.device) - block_table_cpu = block_table_cpu[:padded_batch_size] - - # Context lens - context_lens_np = self.decode_context_lens_np[ - self.cur_swap_id][:padded_batch_size] - np.add(self.input_batch.num_computed_tokens_cpu[:padded_batch_size], - 1, - out=context_lens_np) - context_lens_np[batch_size:] = 0 - - # Get final tensors - input_tokens = input_tokens_cpu.reshape(-1, 1).to(self.device) - input_positions = input_positions_cpu.reshape(-1, 1).to(self.device) - slot_mapping = self.slot_mapping_cpu[ - self.cur_swap_id][:padded_batch_size].reshape(-1, - 1).to(self.device) - block_table = block_table_cpu.to(self.device) - context_lens = self.decode_context_lens_cpu[ - self.cur_swap_id][:padded_batch_size].to(self.device) - - self.swap_step() - - # Attn metadata attn_metadata = PallasMetadata( - num_prefills=0, - num_prefill_tokens=0, - num_decode_tokens=padded_batch_size, slot_mapping=slot_mapping, - multi_modal_placeholder_index_maps=None, - enable_kv_scales_calculation=True, - block_tables=block_table, - context_lens=context_lens, - effective_query_lens=None, + block_tables=padded_block_table, + context_lens=seq_lens, + query_start_loc=query_start_loc, + num_seqs=num_reqs, ) - - return DecodeData(input_tokens=input_tokens, - input_positions=input_positions, - attn_metadata=attn_metadata) + # NOTE(woosuk): Due to chunked prefills, there can be at most 1 partial + # request in the batch. While we should not sample any token from this + # partial request, we do so for simplicity. We will ignore the sampled + # token from the partial request. + # TODO: Support prompt logprobs. 
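The `logits_indices` computed on the next line picks the last scheduled token of each request out of the flattened token stream; a tiny worked example, reusing the `[2, 5, 3]` per-request token counts from earlier (values assumed):

```python
import torch

# Cumulative starts for three requests scheduling [2, 5, 3] tokens.
query_start_loc = torch.tensor([0, 2, 7, 10], dtype=torch.int32)

# The last token of a request sits just before the next request's start.
logits_indices = query_start_loc[1:] - 1
print(logits_indices)  # tensor([1, 6, 9], dtype=torch.int32)
```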
+ logits_indices = query_start_loc[1:] - 1 + return attn_metadata, logits_indices @torch.no_grad() def execute_model( @@ -586,118 +436,81 @@ class TPUModelRunner: # Update cached state self._update_states(scheduler_output) - # If necessary, swap decodes/prompts to have all decodes on the start - ensure_decodes_first(self.input_batch) + # Prepare inputs + attn_metadata, logits_indices = self._prepare_inputs(scheduler_output) + total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens - # Prepare prompts/decodes info - pd_info = self._get_prompts_and_decodes(scheduler_output) + # Run the decoder + with set_forward_context(attn_metadata, self.vllm_config): + hidden_states = self.model( + token_ids=self.input_ids, + position_ids=self.position_ids, + kv_caches=self.kv_caches, + ) + hidden_states = hidden_states[:total_num_scheduled_tokens] + num_reqs = self.input_batch.num_reqs + logits_indices = logits_indices[:num_reqs] + hidden_states = hidden_states[logits_indices] + logits = self.model.compute_logits(hidden_states, None) + selected_token_ids = torch.argmax(logits, dim=-1, keepdim=True) - # Init - num_prompts = len(pd_info.prompt_req_ids) - num_decodes = len(pd_info.decode_req_ids) - decode_data = None - sampled_token_ids = [0] * self.input_batch.num_reqs - - # Run each prompt individually - is_first = True - for i in range(num_prompts): - req_id = pd_info.prompt_req_ids[i] - req_index = num_decodes + i - assert req_index == self.input_batch.req_id_to_index[ - req_id] # TODO: Remove + # Then, let's update the cache state. + request_seq_lens: List[Tuple[int, CachedRequestState, int]] = [] + for i, req_id in zip(range(num_reqs), self.input_batch.req_ids): + assert req_id is not None req_state = self.requests[req_id] - num_scheduled_tokens = pd_info.prompt_scheduled_tokens[i] - prompt_len = num_scheduled_tokens - seq_len = req_state.num_computed_tokens + num_scheduled_tokens + seq_len = (req_state.num_computed_tokens + + scheduler_output.num_scheduled_tokens[req_id]) + if seq_len >= req_state.num_tokens: + request_seq_lens.append((i, req_state, seq_len)) + else: + # Ignore the sampled token from the partial request. + # Rewind the generator state as if the token was not sampled. 
+ generator = self.input_batch.generators.get(i) + if generator is not None: + # This relies on cuda-specific torch-internal impl details + generator.set_offset(generator.get_offset() - 4) - # Prepare first prompt - if is_first: - prompt_data = self._prepare_prompt(req_index, - num_scheduled_tokens) - is_first = False + # num_reqs entries should be non-None + assert all( + req_id is not None for req_id in + self.input_batch.req_ids[:num_reqs]), "req_ids contains None" + req_ids = cast(List[str], self.input_batch.req_ids[:num_reqs]) - # Run forward pass - with set_forward_context(prompt_data.attn_metadata, - self.vllm_config): - assert self.model is not None - selected_token_ids = self.model(prompt_data.input_tokens, - prompt_data.input_positions, - self.kv_caches) - - # In parallel to TPU execution, prepare the next iteration - if i < num_prompts - 1: - # There is next prompt => prepare it - prompt_data = self._prepare_prompt( - req_index + 1, pd_info.prompt_scheduled_tokens[i + 1]) - elif i == num_prompts - 1 and num_decodes > 0: - # There is next decode => prepare it - decode_data = self._prepare_decode(pd_info.decode_req_ids) - - # Update cached state (if prompt is fully done) - if seq_len >= len(req_state.prompt_token_ids): - # Transfer sampled tokens from TPU to CPU - selected_token_ids_cpu = selected_token_ids.cpu() - - # Get output token - token_id = selected_token_ids_cpu[prompt_len - 1].item() - sampled_token_ids[req_index] = token_id - - # Add output token to the request - self.input_batch.token_ids_cpu[req_index, seq_len] = token_id - self.input_batch.num_tokens[req_index] += 1 - req_state.output_token_ids.append(token_id) - - # Run decodes (a single batch) - if num_decodes > 0: - - # Prepare decode (if was not yet prepared) - if decode_data is None: - decode_data = self._prepare_decode(pd_info.decode_req_ids) - - # Run forward pass - with set_forward_context(decode_data.attn_metadata, - self.vllm_config): - assert self.model is not None - selected_token_ids = self.model(decode_data.input_tokens, - decode_data.input_positions, - self.kv_caches) - - # Transfer sampled tokens from TPU to CPU - decode_token_ids_cpu = selected_token_ids.cpu() - # Convert to list - decode_token_ids_list = decode_token_ids_cpu.tolist() - - # Update cached state for each decode request - for i in range(num_decodes): - req_id = pd_info.decode_req_ids[i] - req_index = i - assert req_index == self.input_batch.req_id_to_index[ - req_id] # TODO: Remove - req_state = self.requests[req_id] - seq_len = req_state.num_computed_tokens + 1 - - token_id = decode_token_ids_list[i] - sampled_token_ids[req_index] = token_id - - self.input_batch.token_ids_cpu[req_index, seq_len] = token_id - self.input_batch.num_tokens[req_index] += 1 - req_state.output_token_ids.append(token_id) - - # Create output. 
- all_req_ids = pd_info.decode_req_ids + pd_info.prompt_req_ids prompt_logprobs_dict: Dict[str, Optional[LogprobsTensors]] = {} - for req_id in all_req_ids: + for req_id in self.input_batch.req_ids[:num_reqs]: prompt_logprobs_dict[req_id] = None + max_gen_len = selected_token_ids.shape[-1] + if max_gen_len == 1: + valid_sampled_token_ids = selected_token_ids.tolist() + for i, req_state, seq_len in request_seq_lens: + token_id = valid_sampled_token_ids[i][0] + self.input_batch.token_ids_cpu[i, seq_len] = token_id + req_state.output_token_ids.append(token_id) + self.input_batch.num_tokens[i] += 1 + else: + valid_mask = selected_token_ids != INVALID_TOKEN_ID + gen_lens = valid_mask.sum(dim=1).tolist() + valid_sampled_token_ids = [ + seq.tolist() + for seq in selected_token_ids[valid_mask].split(gen_lens) + ] + self.input_batch.num_tokens[:num_reqs] += gen_lens + for i, req_state, seq_len in request_seq_lens: + target_slice = slice(seq_len - gen_lens[i] + 1, seq_len + 1) + self.input_batch.token_ids_cpu[ + i, target_slice] = valid_sampled_token_ids[i] + req_state.output_token_ids.extend(valid_sampled_token_ids[i]) + model_runner_output = ModelRunnerOutput( - req_ids=all_req_ids, + req_ids=req_ids, req_id_to_index=self.input_batch.req_id_to_index, - sampled_token_ids=[[token_id] for token_id in sampled_token_ids], + sampled_token_ids=valid_sampled_token_ids, spec_token_ids=None, logprobs=None, - prompt_logprobs_dict=prompt_logprobs_dict, # type: ignore[arg-type] + prompt_logprobs_dict=prompt_logprobs_dict, ) - return model_runner_output def load_model(self) -> None: @@ -731,185 +544,63 @@ class TPUModelRunner: self, kv_caches, num_tokens: int, - seq_len: Optional[int] = None, - exec_mode: Optional[ExecutionMode] = None, ) -> None: - assert seq_len is not None - assert exec_mode is not None + input_ids = torch.zeros(num_tokens, + dtype=torch.int32, + device=self.device) + position_ids = torch.zeros(num_tokens, + dtype=torch.int32, + device=self.device) + slot_mapping = torch.zeros(num_tokens, + dtype=torch.int64, + device=self.device) + block_tables = torch.zeros((num_tokens, self.block_table_cpu.shape[1]), + dtype=torch.int32, + device=self.device) + query_lens = [1] * num_tokens + query_start_loc = torch.cumsum(torch.tensor([0] + query_lens, + dtype=torch.int32), + dim=0, + dtype=torch.int32).to(self.device) + context_lens = torch.ones((num_tokens, ), + dtype=torch.int32, + device=self.device) + attn_metadata = PallasMetadata( + slot_mapping=slot_mapping, + block_tables=block_tables, + context_lens=context_lens, + query_start_loc=query_start_loc, + num_seqs=num_tokens, + ) - exec_mode = ExecutionMode(exec_mode) - if exec_mode.is_prefill(): - seq_len = (seq_len + 15) // 16 * 16 - token_ids = torch.zeros((num_tokens, seq_len), - dtype=torch.int32, - device=self.device) - position_ids = torch.zeros((num_tokens, seq_len), - dtype=torch.int32, - device=self.device) - slot_mapping = torch.zeros((num_tokens, seq_len), - dtype=torch.int64, - device=self.device) - if exec_mode == ExecutionMode.PREFILL: - attn_metadata = PallasMetadata( - num_prefills=num_tokens, - num_prefill_tokens=num_tokens * seq_len, - num_decode_tokens=0, - slot_mapping=slot_mapping, - multi_modal_placeholder_index_maps=None, - enable_kv_scales_calculation=True, - block_tables=None, - context_lens=None, - effective_query_lens=None, - ) - - else: - context_lens = torch.ones((num_tokens, ), - dtype=torch.int32, - device=self.device) - - block_tables = torch.zeros( - (num_tokens, self.max_num_blocks_per_req), - dtype=torch.int32, - 
device=self.device) - - effective_query_lens = torch.ones_like(context_lens) - - attn_metadata = PallasMetadata( - num_prefills=num_tokens, - num_prefill_tokens=num_tokens * seq_len, - num_decode_tokens=0, - slot_mapping=slot_mapping, - multi_modal_placeholder_index_maps=None, - enable_kv_scales_calculation=True, - block_tables=block_tables, - context_lens=context_lens, - effective_query_lens=effective_query_lens, - ) - else: - assert seq_len == 1 - token_ids = torch.zeros((num_tokens, seq_len), - dtype=torch.int32, - device=self.device) - position_ids = torch.zeros((num_tokens, seq_len), - dtype=torch.int32, - device=self.device) - slot_mapping = torch.zeros((num_tokens, seq_len), - dtype=torch.int64, - device=self.device) - block_tables = torch.zeros( - (num_tokens, self.max_num_blocks_per_req), - dtype=torch.int32, - device=self.device) - context_lens = torch.ones((num_tokens, ), - dtype=torch.int32, - device=self.device) - attn_metadata = PallasMetadata( - num_prefills=0, - num_prefill_tokens=0, - num_decode_tokens=num_tokens * seq_len, - slot_mapping=slot_mapping, - multi_modal_placeholder_index_maps=None, - enable_kv_scales_calculation=True, - block_tables=block_tables, - context_lens=context_lens, - ) - - # NOTE(woosuk): There are two stages of compilation: torch.compile and - # XLA compilation. Using `mark_dynamic` can reduce the torch.compile - # overhead by reusing the FX graph for different shapes. - # However, the XLA graph will still require static shapes and needs to - # be re-compiled for every different shapes. This overhead is inevitable - # in the first run, but can be skipped afterwards as we cache the XLA - # graphs in the disk (VLLM_XLA_CACHE_PATH). - if exec_mode.is_prefill(): - # Prefll - torch._dynamo.mark_dynamic(token_ids, 1) - torch._dynamo.mark_dynamic(position_ids, 1) - torch._dynamo.mark_dynamic(attn_metadata.slot_mapping, 1) - else: - # Decode - torch._dynamo.mark_dynamic(token_ids, 0) - torch._dynamo.mark_dynamic(position_ids, 0) - torch._dynamo.mark_dynamic(attn_metadata.slot_mapping, 0) - torch._dynamo.mark_dynamic(attn_metadata.context_lens, 0) - torch._dynamo.mark_dynamic(attn_metadata.block_tables, 0) + torch._dynamo.mark_dynamic(input_ids, 0) + torch._dynamo.mark_dynamic(position_ids, 0) + torch._dynamo.mark_dynamic(attn_metadata.slot_mapping, 0) + torch._dynamo.mark_dynamic(attn_metadata.block_tables, 0) + torch._dynamo.mark_dynamic(attn_metadata.query_start_loc, 0) + torch._dynamo.mark_dynamic(attn_metadata.context_lens, 0) with set_forward_context(attn_metadata, self.vllm_config, 0): assert self.model is not None - self.model(token_ids, position_ids, kv_caches) + self.model(input_ids, position_ids, kv_caches) def capture_model(self) -> None: """Compile the model.""" - # Prefill - logger.info( - "Compiling the model with different input shapes for prefill:") - start = time.time() - for batch_size in [1]: - seq_len = 16 - while seq_len <= self.model_config.max_model_len: - self.dummy_run(self.kv_caches, - batch_size, - seq_len, - exec_mode=ExecutionMode.PREFILL) - xm.wait_device_ops() - logger.info(" batch_size: %d, seq_len: %d", batch_size, - seq_len) - num_tokens = batch_size * seq_len - if num_tokens >= self.scheduler_config.max_num_batched_tokens: - break - seq_len = seq_len * 2 + logger.info("Compiling the model with different input shapes.") - end = time.time() - logger.info(" -- Compilation for prefill done in %.2f [secs].", - end - start) - - # Prefix prefill - if self.scheduler_config.enable_chunked_prefill: - logger.info("Compiling the model 
with different input shapes for "
-                        "prefix prefill:")
-            start = time.time()
-            for batch_size in [1]:
-                seq_len = 16
-                while seq_len <= self.model_config.max_model_len:
-                    self.dummy_run(self.kv_caches,
-                                   batch_size,
-                                   seq_len,
-                                   exec_mode=ExecutionMode.PREFIX_PREFILL)
-                    xm.wait_device_ops()
-                    logger.info("  batch_size: %d, seq_len: %d", batch_size,
-                                seq_len)
-                    num_tokens = batch_size * seq_len
-                    if (num_tokens
-                            >= self.scheduler_config.max_num_batched_tokens):
-                        break
-                    seq_len = seq_len * 2
-            end = time.time()
-            logger.info(
-                "    -- Compilation for prefix prefill done in %.2f [secs].",
-                end - start)
-
-        # Decode
-        logger.info(
-            "Compiling the model with different input shapes for decode:")
-        start = time.time()
-        seq_len = 1
-        batch_size = 8  # Must be in sync with _get_padded_batch_size()
+        start = time.perf_counter()
+        num_tokens = 16
         while True:
-            self.dummy_run(self.kv_caches,
-                           batch_size,
-                           seq_len,
-                           exec_mode=ExecutionMode.DECODE)
+            self.dummy_run(self.kv_caches, num_tokens)
+            logger.info("  -- num_tokens: %d", num_tokens)
+            xm.mark_step()
             xm.wait_device_ops()
-            logger.info("  batch_size: %d, seq_len: %d", batch_size, seq_len)
-
-            if batch_size >= self.scheduler_config.max_num_seqs:
+            if num_tokens >= self.scheduler_config.max_num_batched_tokens:
                 break
-            batch_size = batch_size + 16 if batch_size >= 16 else batch_size * 2
-
-        end = time.time()
-        logger.info("    -- Compilation for decode done in %.2f [secs].",
-                    end - start)
+            num_tokens *= 2
+        end = time.perf_counter()
+        logger.info("Compilation finished in %.2f [secs].", end - start)
 
     def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None:
         """
@@ -965,12 +656,8 @@ class ModelWrapperV1(nn.Module):
         """Executes the forward pass of the model and samples the next token.
 
         Args:
-            token_ids: The input token IDs of shape [batch_size, seq_len].
-            position_ids: The input position IDs of shape [batch_size, seq_len].
-            input_lens: The actual input lengths of shape [batch_size].
-            t: The sampling temperature of shape [batch_size].
-            p: The top-p probability of shape [batch_size].
-            num_samples: Number of samples to draw from each logits vector.
+            token_ids: The input token IDs of shape [num_tokens].
+            position_ids: The input position IDs of shape [num_tokens].
             kv_caches: The key and value caches. They can be None during
                 the memory profiling at initialization.
         """
@@ -982,6 +669,7 @@ class ModelWrapperV1(nn.Module):
         # [num_kv_heads, num_blocks, block_size, head_size]. To make it
         # work, we need to flatten the first three dimensions and modify
         # the slot_mapping accordingly.
+        # kv_caches: List[Tuple[torch.Tensor, torch.Tensor]]
         num_kv_heads, num_blocks, block_size, _ = kv_caches[0][0].shape
         slot_mapping = attn_metadata.slot_mapping
         slot_mapping = slot_mapping.flatten()
@@ -997,103 +685,22 @@
         attn_metadata.slot_mapping = slot_mapping
 
         assert self.model is not None
-        hidden_states = self.model(token_ids, position_ids)
+        hidden_states = self.model(
+            token_ids,
+            position_ids,
+            kv_caches,
+        )
 
-        hidden_states = hidden_states.flatten(0, 1)
-        logits = self.model.compute_logits(hidden_states, None)
+        return hidden_states
 
-        # Greedy sampling.
- argmax_token_ids = torch.argmax(logits, dim=-1, keepdim=True) - argmax_token_ids = argmax_token_ids.squeeze(dim=-1) - return argmax_token_ids + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata, + ) -> Optional[torch.Tensor]: + logits = self.model.compute_logits(hidden_states, sampling_metadata) + return logits -def swap_positions(b: InputBatch, id_1, id_2): - assert id_1 != id_2 - req_id_1 = b.req_ids[id_1] - req_id_2 = b.req_ids[id_2] - assert req_id_1 is not None - assert req_id_2 is not None - assert id_1 == b.req_id_to_index[req_id_1] - assert id_2 == b.req_id_to_index[req_id_2] - - b.req_ids[id_1], b.req_ids[id_2] = b.req_ids[id_2], b.req_ids[id_1] - b.req_id_to_index[req_id_1], b.req_id_to_index[ - req_id_2] = b.req_id_to_index[req_id_2], b.req_id_to_index[req_id_1] - - ids = [id_1, id_2] - rev_ids = [id_2, id_1] - b.num_tokens[ids] = b.num_tokens[rev_ids] - b.token_ids_cpu[ids] = b.token_ids_cpu[rev_ids] - b.num_prompt_tokens[ids] = b.num_prompt_tokens[rev_ids] - b.num_computed_tokens_cpu[ids] = b.num_computed_tokens_cpu[rev_ids] - - b.block_table.swap_row(id_1, id_2) - - b.temperature_cpu[ids] = b.temperature_cpu[rev_ids] - b.top_p_cpu[ids] = b.top_p_cpu[rev_ids] - b.top_k_cpu[ids] = b.top_k_cpu[rev_ids] - b.frequency_penalties_cpu[ids] = b.frequency_penalties_cpu[rev_ids] - b.presence_penalties_cpu[ids] = b.presence_penalties_cpu[rev_ids] - b.repetition_penalties_cpu[ids] = b.repetition_penalties_cpu[rev_ids] - - b.min_tokens[id_1], b.min_tokens[id_2] = b.min_tokens[id_2], b.min_tokens[ - id_1] - - gen_1 = b.generators.pop(id_1, None) - gen_2 = b.generators.pop(id_2, None) - if gen_1 is not None: - b.generators[id_2] = gen_1 - if gen_2 is not None: - b.generators[id_1] = gen_2 - - -def ensure_decodes_first(b: InputBatch): - num_reqs = b.num_reqs - while True: - # Find the first prompt index - first_prompt_index = None - for i in range(num_reqs): - if b.num_computed_tokens_cpu[i] < b.num_prompt_tokens[i]: - first_prompt_index = i - break - if first_prompt_index is None: - break - - # Find the last decode index - last_decode_index = None - for i in reversed(range(num_reqs)): - if b.num_computed_tokens_cpu[i] >= b.num_prompt_tokens[i]: - last_decode_index = i - break - if last_decode_index is None: - break - - # Sanity - assert first_prompt_index != last_decode_index - - # Check if done - if first_prompt_index > last_decode_index: - break - - # Swap - swap_positions(b, first_prompt_index, last_decode_index) - - -def _get_padded_prompt_len(x: int) -> int: - # NOTE(woosuk): The pallas FlashAttention kernel requires the sequence - # length to be a multiple of 16. We pad the prompt length to the nearest - # multiple of 16. This is also good for performance. - if x <= 16: - return 16 - return 1 << (x - 1).bit_length() - - -def _get_padded_batch_size(batch_size: int) -> int: - # The GMM Pallas kernel requires num_tokens * topk to be a multiple of 16. - # To meet this requirement in the simplest way, we set the minimal batch - # size to 8. 
- if batch_size <= 8: - return 8 - else: - return ((batch_size + 15) // 16) * 16 +def _get_padded_number(n: int, multiple: int) -> int: + return ((n + multiple - 1) // multiple) * multiple diff --git a/vllm/v1/worker/tpu_worker.py b/vllm/v1/worker/tpu_worker.py index c236f263eddb..405dc628ee1c 100644 --- a/vllm/v1/worker/tpu_worker.py +++ b/vllm/v1/worker/tpu_worker.py @@ -21,7 +21,7 @@ from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig, KVCacheSpec) from vllm.v1.outputs import ModelRunnerOutput from vllm.v1.utils import bind_kv_cache -from vllm.v1.worker.tpu_model_runner import ExecutionMode, TPUModelRunner +from vllm.v1.worker.tpu_model_runner import TPUModelRunner logger = init_logger(__name__) @@ -126,9 +126,7 @@ class TPUWorker: self.model_runner.dummy_run( runner_kv_caches, - num_tokens=1, - seq_len=self.scheduler_config.max_num_batched_tokens, - exec_mode=ExecutionMode.PREFILL, + num_tokens=self.scheduler_config.max_num_batched_tokens, ) # Synchronize before measuring the memory usage.
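As a closing illustration (standalone sketch; `max_num_batched_tokens = 512` is an assumed example value, not taken from the patch), the rewritten `capture_model` warms up one XLA graph per power-of-two token count, starting at 16:

```python
def warmup_sizes(max_num_batched_tokens: int) -> list:
    """Token counts compiled by the doubling loop in capture_model."""
    sizes, num_tokens = [], 16
    while True:
        sizes.append(num_tokens)
        if num_tokens >= max_num_batched_tokens:
            break
        num_tokens *= 2
    return sizes

print(warmup_sizes(512))  # [16, 32, 64, 128, 256, 512]
```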