diff --git a/tests/entrypoints/llm/test_accuracy.py b/tests/entrypoints/llm/test_accuracy.py index 29ff00df6d50..620355923b47 100644 --- a/tests/entrypoints/llm/test_accuracy.py +++ b/tests/entrypoints/llm/test_accuracy.py @@ -21,10 +21,13 @@ RTOL = 0.03 EXPECTED_VALUE = 0.58 -def run_test(): +def run_test(more_args=None): """Run the end to end accuracy test.""" - model_args = f"pretrained={MODEL_NAME},max_model_len=2048" + model_args = f"pretrained={MODEL_NAME},max_model_len=4096" + + if more_args is not None: + model_args = "{},{}".format(model_args, more_args) results = lm_eval.simple_evaluate( model="vllm", @@ -39,14 +42,21 @@ def run_test(): ), f"Expected: {EXPECTED_VALUE} | Measured: {measured_value}" -@pytest.mark.skipif(not current_platform.is_cuda(), - reason="V1 is currently only supported on CUDA.") +@pytest.mark.skipif(not current_platform.is_cuda() + and not current_platform.is_tpu(), + reason="V1 is currently only supported on CUDA and TPU") def test_lm_eval_accuracy_v1_engine(monkeypatch): """Run with the V1 Engine.""" with monkeypatch.context() as m: m.setenv("VLLM_USE_V1", "1") - run_test() + + more_args = None + if current_platform.is_tpu(): + # Limit compilation time for TPU V1 + more_args = "max_num_seqs=64" + + run_test(more_args) def test_lm_eval_accuracy_v0_engine(monkeypatch): diff --git a/tests/entrypoints/openai/correctness/test_lmeval.py b/tests/entrypoints/openai/correctness/test_lmeval.py index ebb2ea4d9d14..902df929e782 100644 --- a/tests/entrypoints/openai/correctness/test_lmeval.py +++ b/tests/entrypoints/openai/correctness/test_lmeval.py @@ -21,7 +21,7 @@ TASK = "gsm8k" FILTER = "exact_match,strict-match" RTOL = 0.03 EXPECTED_VALUE = 0.58 -DEFAULT_ARGS = ["--max-model-len", "2048", "--disable-log-requests"] +DEFAULT_ARGS = ["--max-model-len", "4096", "--disable-log-requests"] MORE_ARGS_LIST = [ [], # Default ["--enable-chunked-prefill"], # Chunked @@ -67,14 +67,21 @@ def run_test(more_args): ), f"Expected: {EXPECTED_VALUE} | Measured: {measured_value}" -@pytest.mark.skipif(not current_platform.is_cuda(), - reason="V1 currently only supported on CUDA") +@pytest.mark.skipif(not current_platform.is_cuda() + and not current_platform.is_tpu(), + reason="V1 currently only supported on CUDA and TPU") def test_lm_eval_accuracy_v1_engine(monkeypatch): """Run with the V1 Engine.""" with monkeypatch.context() as m: m.setenv("VLLM_USE_V1", "1") - run_test([]) + more_args = [] + + # Limit compilation time for V1 + if current_platform.is_tpu(): + more_args = ["--max-num-seqs", "64"] + + run_test(more_args) @pytest.mark.parametrize("more_args", MORE_ARGS_LIST) diff --git a/vllm/platforms/interface.py b/vllm/platforms/interface.py index 61673b08543f..19adc2af8c67 100644 --- a/vllm/platforms/interface.py +++ b/vllm/platforms/interface.py @@ -37,6 +37,7 @@ class _Backend(enum.Enum): TRITON_MLA = enum.auto() HPU_ATTN = enum.auto() PALLAS = enum.auto() + PALLAS_VLLM_V1 = enum.auto() IPEX = enum.auto() BLOCK_SPARSE_FLASH_ATTN = enum.auto() NO_ATTENTION = enum.auto() diff --git a/vllm/platforms/tpu.py b/vllm/platforms/tpu.py index fffc61bbaaca..0c81d6a9389b 100644 --- a/vllm/platforms/tpu.py +++ b/vllm/platforms/tpu.py @@ -4,6 +4,7 @@ from typing import TYPE_CHECKING, Optional import torch +import vllm.envs as envs from vllm.logger import init_logger from .interface import Platform, PlatformEnum, _Backend @@ -33,14 +34,20 @@ class TpuPlatform(Platform): dtype: torch.dtype, kv_cache_dtype: Optional[str], block_size: int, use_v1: bool, use_mla: bool) -> str: - if selected_backend 
!= _Backend.PALLAS: + if (selected_backend != _Backend.PALLAS + and selected_backend != _Backend.PALLAS_VLLM_V1): logger.info("Cannot use %s backend on TPU.", selected_backend) - logger.info("Using Pallas backend.") - return "vllm.attention.backends.pallas.PallasAttentionBackend" + + if use_v1: + logger.info("Using Pallas V1 backend.") + return "vllm.v1.attention.backends.pallas.PallasAttentionBackend" + else: + logger.info("Using Pallas backend.") + return "vllm.attention.backends.pallas.PallasAttentionBackend" @classmethod def get_device_name(cls, device_id: int = 0) -> str: - raise NotImplementedError + return "tpu" @classmethod def get_device_total_memory(cls, device_id: int = 0) -> int: @@ -48,7 +55,7 @@ class TpuPlatform(Platform): @classmethod def is_async_output_supported(cls, enforce_eager: Optional[bool]) -> bool: - return True + return not envs.VLLM_USE_V1 @classmethod def inference_mode(cls): @@ -63,11 +70,11 @@ class TpuPlatform(Platform): cache_config.block_size = 16 compilation_config = vllm_config.compilation_config - if compilation_config.level == CompilationLevel.NO_COMPILATION: - # TPU does not support NO_COMPILATION + + # TPU only supports DYNAMO_ONCE compilation level + if compilation_config.level != CompilationLevel.DYNAMO_ONCE: + logger.info("[TPU] Forcing DYNAMO_ONCE compilation level") compilation_config.level = CompilationLevel.DYNAMO_ONCE - assert compilation_config.level < CompilationLevel.PIECEWISE,\ - "TPU does not support Inductor." if compilation_config.backend == "": compilation_config.backend = "openxla" @@ -75,10 +82,6 @@ class TpuPlatform(Platform): assert vllm_config.speculative_config is None, \ "TPU does not support speculative decoding" - assert not vllm_config.scheduler_config.chunked_prefill_enabled, ( - "Chunked prefill is not yet supported for TPU backend") - assert not vllm_config.speculative_config, ( - "Speculative decoding is not yet supported for TPU backend") if vllm_config.model_config.dtype in (torch.float16, torch.float32): logger.warning( "The TPU backend currently does not support %s. 
" @@ -88,8 +91,27 @@ class TpuPlatform(Platform): parallel_config = vllm_config.parallel_config scheduler_config = vllm_config.scheduler_config if parallel_config.worker_cls == "auto": - if scheduler_config.is_multi_step: + if envs.VLLM_USE_V1: parallel_config.worker_cls = \ - "vllm.worker.multi_step_tpu_worker.MultiStepTPUWorker" + "vllm.v1.worker.tpu_worker.TPUWorker" else: - parallel_config.worker_cls = "vllm.worker.tpu_worker.TPUWorker" + if scheduler_config.is_multi_step: + parallel_config.worker_cls = \ + "vllm.worker.multi_step_tpu_worker.MultiStepTPUWorker" + else: + parallel_config.worker_cls = \ + "vllm.worker.tpu_worker.TPUWorker" + + # Adjust scheduler config for V1 + # TODO: Add support for these + if envs.VLLM_USE_V1 and vllm_config.cache_config.enable_prefix_caching: + logger.warning("[V1][TPU] Disable prefix caching") + vllm_config.cache_config.enable_prefix_caching = False + + assert not vllm_config.speculative_config, ( + "Speculative decoding is not yet supported for TPU backend") + + @classmethod + def is_pin_memory_available(cls): + logger.warning("Pin memory is not supported on TPU.") + return False diff --git a/vllm/v1/attention/backends/pallas.py b/vllm/v1/attention/backends/pallas.py new file mode 100644 index 000000000000..37bf33f6e3e9 --- /dev/null +++ b/vllm/v1/attention/backends/pallas.py @@ -0,0 +1,353 @@ +# SPDX-License-Identifier: Apache-2.0 + +from dataclasses import dataclass +from typing import Any, Dict, List, Optional, Tuple, Type + +import torch +import torch_xla.experimental.custom_kernel # Required to register custom ops. + +from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, + AttentionLayer, + AttentionMetadata, AttentionType) +from vllm.attention.backends.utils import CommonAttentionState + + +class PallasAttentionBackend(AttentionBackend): + + @staticmethod + def get_name() -> str: + return "PALLAS_VLLM_V1" + + @staticmethod + def get_impl_cls() -> Type["PallasAttentionBackendImpl"]: + return PallasAttentionBackendImpl + + @staticmethod + def get_metadata_cls() -> Type["PallasMetadata"]: + return PallasMetadata + + @staticmethod + def get_state_cls() -> Type["CommonAttentionState"]: + return CommonAttentionState + + @staticmethod + def get_kv_cache_shape( + num_blocks: int, + block_size: int, + num_kv_heads: int, + head_size: int, + ) -> Tuple[int, ...]: + return (num_kv_heads, num_blocks, block_size, head_size) + + @staticmethod + def swap_blocks( + src_kv_cache: torch.Tensor, + dst_kv_cache: torch.Tensor, + src_to_dst: torch.Tensor, + ) -> None: + raise RuntimeError("swap_blocks is not used for the TPU backend.") + + @torch.compile(backend="openxla") + @staticmethod + def copy_blocks( + kv_caches: List[Tuple[torch.Tensor, torch.Tensor]], + src_to_dists: Tuple[torch.Tensor, torch.Tensor], + ) -> None: + src_indices, dst_indices = src_to_dists + for k_cache, v_cache in kv_caches: + torch.ops.xla.dynamo_set_buffer_donor_(k_cache, True) + k_cache[:, dst_indices] = k_cache[:, src_indices] + torch.ops.xla.dynamo_set_buffer_donor_(v_cache, True) + v_cache[:, dst_indices] = v_cache[:, src_indices] + + +@dataclass +class PallasMetadata(AttentionMetadata): + + # Currently, input sequences can only contain all prefills + # or all decoding. 
+ block_tables: Optional[torch.Tensor] = None + context_lens: Optional[torch.Tensor] = None + effective_query_lens: Optional[torch.Tensor] = None + + @property + def prefill_metadata(self) -> Optional["PallasMetadata"]: + if self.num_prefills == 0: + return None + + assert self.num_decode_tokens == 0 + return self + + @property + def decode_metadata(self) -> Optional["PallasMetadata"]: + if self.num_decode_tokens == 0: + return None + + assert self.num_prefills == 0 + assert self.num_prefill_tokens == 0 + assert self.block_tables is not None + assert self.context_lens is not None + return self + + +class PallasAttentionBackendImpl(AttentionImpl): + + def __init__( + self, + num_heads: int, + head_size: int, + scale: float, + num_kv_heads: int, + alibi_slopes: Optional[List[float]], + sliding_window: Optional[int], + kv_cache_dtype: str, + blocksparse_params: Optional[Dict[str, Any]] = None, + logits_soft_cap: Optional[float] = None, + attn_type: str = AttentionType.DECODER, + ) -> None: + self.num_heads = num_heads + self.head_size = head_size + self.scale = float(scale) + self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads + + assert self.num_heads % self.num_kv_heads == 0 + self.num_queries_per_kv = self.num_heads // self.num_kv_heads + if head_size % 128 != 0: + raise NotImplementedError("Head size must be a multiple of 128.") + if alibi_slopes is not None: + raise NotImplementedError("Alibi slopes is not supported.") + if sliding_window is not None: + raise NotImplementedError("Sliding window is not supported.") + if kv_cache_dtype != "auto": + raise NotImplementedError("FP8 KV cache dtype is not supported.") + if blocksparse_params is not None: + raise NotImplementedError("Blocksparse is not supported.") + if logits_soft_cap is not None: + raise NotImplementedError( + "Attention logits soft-capping is not supported.") + + if torch_xla.tpu.version() < 4: + raise NotImplementedError("TPU version must be 4 or higher.") + + self.megacore_mode = None + tpu_env = torch_xla.tpu.get_tpu_env() + tpu_type = (tpu_env.get("ACCELERATOR_TYPE", None) + or tpu_env.get("TYPE", None) + or tpu_env.get("TPU_ACCELERATOR_TYPE", None)) + assert tpu_type is not None + tpu_type = tpu_type.lower() + + if (("lite" not in tpu_type) and ("v6" not in tpu_type)): + if self.num_kv_heads % 2 == 0: + self.megacore_mode = "kv_head" + else: + # NOTE(woosuk): If the batch size is not a multiple of 2, the + # megacore mode will be None. + self.megacore_mode = "batch" + + if attn_type != AttentionType.DECODER: + raise NotImplementedError("Encoder self-attention and " + "encoder/decoder cross-attention " + "are not implemented for " + "PallasAttentionBackendImpl") + + def forward( + self, + layer: AttentionLayer, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + kv_cache: Tuple[torch.Tensor, torch.Tensor], + attn_metadata: PallasMetadata, + output: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + """Forward pass with Pallas attention. + + Args: + query: shape = [batch_size, seq_len, num_heads * head_size] + key: shape = [batch_size, seq_len, num_kv_heads * head_size] + value: shape = [batch_size, seq_len, num_kv_heads * head_size] + kv_cache[0] = [num_kv_heads, num_blocks, block_size, head_size] + kv_cache[1] = [num_kv_heads, num_blocks, block_size, head_size] + NOTE: kv_cache[0] and kv_cache[1] will be an empty tensor + with shape [0] for profiling run. + attn_metadata: Metadata for attention. 
+ Returns: + shape = [batch_size, seq_len, num_heads * head_size] + """ + + if attn_metadata is None: + if output is None: + output = torch.ones_like(query) + return output + + assert layer._k_scale_float == 1.0 and layer._v_scale_float == 1.0 + batch_size, seq_len, hidden_size = query.shape + query = query.view(batch_size, seq_len, self.num_heads, self.head_size) + key = key.view(batch_size, seq_len, self.num_kv_heads, self.head_size) + value = value.view(batch_size, seq_len, self.num_kv_heads, + self.head_size) + + if kv_cache[0].numel() > 0: + slot_mapping = attn_metadata.slot_mapping + key_cache, value_cache = kv_cache + write_to_kv_cache(key, value, key_cache, value_cache, slot_mapping) + + query = query * self.scale + if attn_metadata.num_prefills > 0: + if attn_metadata.block_tables is None: + # Prefill without paged KV cache. + assert seq_len % 16 == 0, ( + "Pallas FlashAttention kernel requires seq_len to be a " + f"multiple of 16 but got {seq_len}") + + # Handle GQA/MQA. + if self.num_kv_heads != self.num_heads: + key = key.repeat_interleave(self.num_queries_per_kv, + dim=-2) + key = key.view(batch_size, seq_len, self.num_heads, + self.head_size) + value = value.repeat_interleave(self.num_queries_per_kv, + dim=-2) + value = value.view(batch_size, seq_len, self.num_heads, + self.head_size) + # FlashAttention kernel requires the input shape to be + # [batch_size, num_heads, seq_len, d_model] + # while the input is [batch_size, seq_len, num_heads, d_model]. + # Permute the input to match the required format. + output = torch.ops.xla.flash_attention( + query.permute(0, 2, 1, 3), + key.permute(0, 2, 1, 3), + value.permute(0, 2, 1, 3), + True, + ) + output = output.permute(0, 2, 1, 3) + else: + # Prefill with paged KV cache. + # TODO(woosuk): Tune the below knobs. + num_kv_pages_per_compute_block = 16 + num_queries_per_compute_block = 16 + assert seq_len % num_queries_per_compute_block == 0 + output = torch.ops.xla.multi_queries_paged_attention( + query, + key_cache, + value_cache, + attn_metadata.context_lens, + attn_metadata.block_tables, + attn_metadata.effective_query_lens, + num_kv_pages_per_compute_block, + num_queries_per_compute_block, + use_kernel=True, + ) + else: + # Decoding run. + assert kv_cache[0].numel() > 0 + query = query.squeeze(dim=1) + pages_per_compute_block = 16 # TODO(woosuk): Tune this value. + + assert attn_metadata.block_tables is not None + assert attn_metadata.context_lens is not None + # NOTE(woosuk): The PagedAttention Pallas kernel stores the entire + # block table in SMEM. Therefore, if the block table is too large, + # the kernel compilation will fail. To avoid this, we split the + # batch dimension into smaller chunks and run the kernel multiple + # times. + MAX_SMEM_USAGE = 512 * 1024 + size_per_seq = 4 * attn_metadata.block_tables.shape[1] + max_num_seq = MAX_SMEM_USAGE // size_per_seq + + if batch_size <= max_num_seq: + output = paged_attention( + query, + key_cache, + value_cache, + attn_metadata.context_lens, + attn_metadata.block_tables, + pages_per_compute_block, + self.megacore_mode, + ) + else: + chunk_size = max_num_seq + # Make sure the chunk size is a multiple of 2. + chunk_size = chunk_size // 2 * 2 + num_chunks = (batch_size + chunk_size - 1) // chunk_size + + output = torch.empty_like(query) + for chunk_idx in range(num_chunks): + chunk_start = chunk_idx * chunk_size + chunk_end = chunk_start + chunk_size + # NOTE(woosuk): We skip this line because it causes Dynamo + # compilation error. 
Instead, we rely on the slice operation + # to handle the out-of-bound case. + # chunk_end = min(chunk_end, batch_size) + chunk_output = paged_attention( + query[chunk_start:chunk_end], + key_cache, + value_cache, + attn_metadata.context_lens[chunk_start:chunk_end], + attn_metadata.block_tables[chunk_start:chunk_end], + pages_per_compute_block, + self.megacore_mode, + ) + output[chunk_start:chunk_end] = chunk_output + + # Reshape the output tensor. + return output.reshape(batch_size, seq_len, hidden_size) + + +def write_to_kv_cache( + key: torch.Tensor, + value: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + slot_mapping: torch.Tensor, +) -> None: + torch.ops.xla.dynamo_set_buffer_donor_(key_cache, True) + torch.ops.xla.dynamo_set_buffer_donor_(value_cache, True) + + key = key.flatten(0, 2) + value = value.flatten(0, 2) + key_cache = key_cache.flatten(0, 2) + value_cache = value_cache.flatten(0, 2) + key_cache.index_copy_(0, slot_mapping, key) + value_cache.index_copy_(0, slot_mapping, value) + + +def paged_attention( + query: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + context_lens: torch.Tensor, + block_tables: torch.Tensor, + pages_per_compute_block: int, + megacore_mode: Optional[str], +) -> torch.Tensor: + batch_size = query.shape[0] + if megacore_mode == "batch" and batch_size % 2 != 0: + megacore_mode = None + else: + megacore_mode = megacore_mode + + # NOTE(woosuk): A temporary workaround to avoid the error: + # "xla::paged_attention() Expected a value of type 'str' for + # argument 'megacore_mode' but instead found type 'NoneType'." + if megacore_mode is not None: + output = torch.ops.xla.paged_attention( + query, + key_cache, + value_cache, + context_lens, + block_tables, + pages_per_compute_block, + megacore_mode=megacore_mode, + ) + else: + output = torch.ops.xla.paged_attention( + query, + key_cache, + value_cache, + context_lens, + block_tables, + pages_per_compute_block, + ) + return output diff --git a/vllm/v1/worker/block_table.py b/vllm/v1/worker/block_table.py index f520ee9586c5..669175f5d9c3 100644 --- a/vllm/v1/worker/block_table.py +++ b/vllm/v1/worker/block_table.py @@ -61,6 +61,14 @@ class BlockTable: src, :num_blocks] self.num_blocks_per_row[tgt] = num_blocks + def swap_row(self, src: int, tgt: int) -> None: + num_blocks_src = self.num_blocks_per_row[src] + num_blocks_tgt = self.num_blocks_per_row[tgt] + self.num_blocks_per_row[src] = num_blocks_tgt + self.num_blocks_per_row[tgt] = num_blocks_src + + self.block_table_np[[src, tgt]] = self.block_table_np[[tgt, src]] + def commit(self, num_reqs: int) -> None: self.block_table[:num_reqs].copy_(self.block_table_cpu[:num_reqs], non_blocking=True) diff --git a/vllm/v1/worker/tpu_model_runner.py b/vllm/v1/worker/tpu_model_runner.py new file mode 100644 index 000000000000..b64581bf5f42 --- /dev/null +++ b/vllm/v1/worker/tpu_model_runner.py @@ -0,0 +1,1109 @@ +# SPDX-License-Identifier: Apache-2.0 +import enum +import time +from dataclasses import dataclass +from typing import TYPE_CHECKING, Dict, List, Optional, Tuple +from unittest.mock import patch + +import numpy as np +import torch +import torch.distributed +import torch.nn as nn +# TPU XLA related +import torch_xla.core.xla_model as xm +import torch_xla.runtime as xr + +from vllm.attention import AttentionMetadata +from vllm.attention.backends.abstract import AttentionType +from vllm.attention.layer import Attention +from vllm.config import VllmConfig +from vllm.forward_context import set_forward_context +from 
vllm.logger import init_logger +from vllm.model_executor.model_loader import get_model +from vllm.sampling_params import SamplingType +from vllm.utils import LayerBlockType, cdiv, is_pin_memory_available +from vllm.v1.attention.backends.pallas import (PallasAttentionBackend, + PallasMetadata) +from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig, + KVCacheSpec) +from vllm.v1.outputs import LogprobsTensors, ModelRunnerOutput +from vllm.v1.utils import bind_kv_cache +from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch + +if TYPE_CHECKING: + from vllm.v1.core.scheduler import SchedulerOutput + +logger = init_logger(__name__) + +# Here we utilize the behavior that out-of-bound index is ignored. +# FIXME(woosuk): Find a more reliable way to prevent possible bugs. +_PAD_SLOT_ID = 1_000_000_000 + + +class ExecutionMode(enum.Enum): + PREFILL = enum.auto() + DECODE = enum.auto() + PREFIX_PREFILL = enum.auto() + + def is_prefill(self) -> bool: + return self in (ExecutionMode.PREFILL, ExecutionMode.PREFIX_PREFILL) + + +@dataclass +class PromptDecodeInfo: + prompt_req_ids: List[str] + decode_req_ids: List[str] + prompt_scheduled_tokens: List[int] + + +@dataclass +class PromptData: + input_tokens: torch.Tensor + input_positions: torch.Tensor + attn_metadata: PallasMetadata + + +@dataclass +class DecodeData: + input_tokens: Optional[torch.Tensor] = None + input_positions: Optional[torch.Tensor] = None + attn_metadata: Optional[PallasMetadata] = None + + +class TPUModelRunner: + + def __init__( + self, + vllm_config: VllmConfig, + device: torch.device, + ): + self.vllm_config = vllm_config + self.model_config = vllm_config.model_config + self.cache_config = vllm_config.cache_config + self.lora_config = vllm_config.lora_config + self.load_config = vllm_config.load_config + self.parallel_config = vllm_config.parallel_config + self.scheduler_config = vllm_config.scheduler_config + self.speculative_config = vllm_config.speculative_config + self.prompt_adapter_config = vllm_config.prompt_adapter_config + self.observability_config = vllm_config.observability_config + self.device_config = vllm_config.device_config + + model_config = self.model_config + cache_config = self.cache_config + scheduler_config = self.scheduler_config + parallel_config = self.parallel_config + self.device = device + self.pin_memory = is_pin_memory_available() + self.dtype = self.model_config.dtype + + self.is_multimodal_model = model_config.is_multimodal_model + self.sliding_window = model_config.get_sliding_window() + self.block_size = cache_config.block_size + self.max_model_len = model_config.max_model_len + self.max_num_blocks_per_req = cdiv(self.max_model_len, self.block_size) + self.max_num_tokens = scheduler_config.max_num_batched_tokens + self.max_num_reqs = scheduler_config.max_num_seqs + + # Model-related. + self.num_attn_layers = model_config.get_num_layers_by_block_type( + parallel_config, LayerBlockType.attention) + self.num_query_heads = model_config.get_num_attention_heads( + parallel_config) + self.num_kv_heads = model_config.get_num_kv_heads(parallel_config) + self.head_size = model_config.get_head_size() + self.hidden_size = model_config.get_hidden_size() + + self.model: Optional[nn.Module] = None + + # Persistent batch. 
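+        # Reuses the V1 GPU InputBatch/CachedRequestState bookkeeping
+        # (vllm.v1.worker.gpu_input_batch); only the input preparation and
+        # model execution paths below are TPU-specific.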
+ self.input_batch = InputBatch( + max_num_reqs=self.max_num_reqs, + max_model_len=self.max_model_len, + max_num_blocks_per_req=self.max_num_blocks_per_req, + device=self.device, + pin_memory=self.pin_memory, + vocab_size=self.model_config.get_vocab_size(), + ) + + # Request states. + self.requests: Dict[str, CachedRequestState] = {} + + # req_id -> (input_id -> encoder_output) + self.encoder_cache: Dict[str, Dict[int, torch.Tensor]] = {} + + # KV caches for forward pass + self.kv_caches: List[Tuple[torch.Tensor, torch.Tensor]] = [] + + # Cached torch/numpy tensors + self.num_swaps = 2 + self.cur_swap_id = 0 + self.input_ids_cpu = [] + self.input_ids_np = [] + self.input_positions_cpu = [] + self.input_positions_np = [] + self.slot_mapping_cpu = [] + self.slot_mapping_np = [] + self.prompt_context_lens_cpu = [] + self.prompt_effective_query_lens_cpu = [] + self.decode_context_lens_cpu = [] + self.decode_context_lens_np = [] + for _ in range(self.num_swaps): + self.input_ids_cpu.append( + torch.empty(self.max_num_tokens, + dtype=torch.int32, + device="cpu")) + self.input_ids_np.append(self.input_ids_cpu[-1].numpy()) + + self.input_positions_cpu.append( + torch.empty(self.max_num_tokens, + dtype=torch.int32, + device="cpu")) + self.input_positions_np.append( + self.input_positions_cpu[-1].numpy()) + + self.slot_mapping_cpu.append( + torch.empty(self.max_num_tokens, + dtype=torch.int64, + device="cpu")) + self.slot_mapping_np.append(self.slot_mapping_cpu[-1].numpy()) + + self.prompt_context_lens_cpu.append( + torch.empty((1), dtype=torch.int32, device="cpu")) + self.prompt_effective_query_lens_cpu.append( + torch.empty((1), dtype=torch.int32, device="cpu")) + + self.decode_context_lens_cpu.append( + torch.empty(self.max_num_tokens, + dtype=torch.int32, + device="cpu")) + self.decode_context_lens_np.append( + self.decode_context_lens_cpu[-1].numpy()) + + # Range tensor with values [0 .. self.max_num_tokens - 1]. + # Used to initialize positions / context_lens / seq_lens + self.arange_np = np.arange(self.max_num_tokens, dtype=np.int32) + + def _update_states(self, scheduler_output: "SchedulerOutput") -> bool: + """Update the cached states and the persistent batch with the scheduler + output. + + The updated states are used by the `_prepare_inputs` function to create + the input GPU tensors for the model. + + Returns: + True if there is a new/resumed/paused/finished request in the batch. + If False, we can skip copying SamplingMetadata to the GPU. + """ + # Remove finished requests from the cached states. + for req_id in scheduler_output.finished_req_ids: + self.requests.pop(req_id, None) + + # Remove the finished requests from the persistent batch. + # NOTE(woosuk): There could be an edge case where finished_req_ids and + # scheduled_req_ids overlap. This happens when a request is aborted and + # then resubmitted with the same ID. In this case, we treat them as two + # distinct requests - clearing the cached states for the first request + # and handling the second as a new request. + removed_req_indices: List[int] = [] + for req_id in scheduler_output.finished_req_ids: + req_index = self.input_batch.remove_request(req_id) + if req_index is not None: + removed_req_indices.append(req_index) + + # Remove the unscheduled requests from the persistent batch. + # NOTE(woosuk): The unscheduled requests are either preempted requests + # or running requests that are not scheduled in this step. 
We remove + # them from the persistent batch but keep their cached states since + # they will be scheduled again sometime in the future. + scheduled_req_ids = scheduler_output.num_scheduled_tokens.keys() + cached_req_ids = self.input_batch.req_id_to_index.keys() + unscheduled_req_ids = cached_req_ids - scheduled_req_ids + # NOTE(woosuk): The persistent batch optimization assumes that + # consecutive batches contain mostly the same requests. If batches + # have low request overlap (e.g., alternating between two distinct + # sets of requests), this optimization becomes very inefficient. + for req_id in unscheduled_req_ids: + req_index = self.input_batch.remove_request(req_id) + assert req_index is not None + removed_req_indices.append(req_index) + + req_ids_to_add: List[str] = [] + # Add new requests to the cached states. + for new_req_data in scheduler_output.scheduled_new_reqs: + req_id = new_req_data.req_id + sampling_params = new_req_data.sampling_params + if sampling_params.sampling_type == SamplingType.RANDOM_SEED: + generator = torch.Generator(device=self.device) + generator.manual_seed(sampling_params.seed) + else: + generator = None + + self.requests[req_id] = CachedRequestState( + req_id=req_id, + prompt_token_ids=new_req_data.prompt_token_ids, + prompt=new_req_data.prompt, + mm_inputs=new_req_data.mm_inputs, + mm_positions=new_req_data.mm_positions, + sampling_params=sampling_params, + generator=generator, + block_ids=new_req_data.block_ids, + num_computed_tokens=new_req_data.num_computed_tokens, + output_token_ids=[], + lora_request=new_req_data.lora_request, + ) + + req_ids_to_add.append(req_id) + + # Update the states of the running/resumed requests. + for req_data in scheduler_output.scheduled_cached_reqs: + req_id = req_data.req_id + req_state = self.requests[req_id] + + # Update the cached states. + req_state.num_computed_tokens = req_data.num_computed_tokens + if not req_data.resumed_from_preemption: + # Append the new blocks to the existing block IDs. + req_state.block_ids.extend(req_data.new_block_ids) + else: + # The request is resumed from preemption. + # Replace the existing block IDs with the new ones. + req_state.block_ids = req_data.new_block_ids + + req_index = self.input_batch.req_id_to_index.get(req_id) + if req_index is None: + # The request is not in the persistent batch. + # The request was either preempted and resumed later, or was not + # scheduled in the previous step and needs to be added again. + req_ids_to_add.append(req_id) + continue + + # Update the persistent batch. + self.input_batch.num_computed_tokens_cpu[req_index] = ( + req_data.num_computed_tokens) + start_index = len(req_state.block_ids) - len( + req_data.new_block_ids) + self.input_batch.block_table.append_row(req_index, start_index, + req_data.new_block_ids) + + # Add the new or resumed requests to the persistent batch. + # The smaller empty indices are filled first. + removed_req_indices = sorted(removed_req_indices, reverse=True) + for req_id in req_ids_to_add: + req_state = self.requests[req_id] + if removed_req_indices: + # Fill the empty index. + req_index = removed_req_indices.pop() + else: + # Append to the end. + req_index = None + self.input_batch.add_request(req_state, req_index) + + # Condense the batched states if there are empty indices. 
+ if removed_req_indices: + self.input_batch.condense(removed_req_indices) + return len(unscheduled_req_ids) > 0 or len(req_ids_to_add) > 0 + + def swap_step(self): + self.cur_swap_id = (self.cur_swap_id + 1) % self.num_swaps + + def get_model(self) -> nn.Module: + assert self.model is not None + return self.model + + def get_kv_cache_spec(self) -> KVCacheSpec: + """ + Generates the KVCacheSpec by parsing the kv cache format from each + Attention module in the static forward context. + Returns: + KVCacheSpec: A dictionary mapping layer names to their KV cache + format. Layers that do not need KV cache are not included. + """ + + forward_ctx = self.vllm_config.compilation_config.static_forward_context + block_size = self.vllm_config.cache_config.block_size + kv_cache_spec: KVCacheSpec = {} + for layer_name, attn_module in forward_ctx.items(): + # TODO: Support other attention modules, e.g., sliding window, + # cross-attention, MLA. + assert isinstance(attn_module, Attention) + if attn_module.attn_type == AttentionType.DECODER: + kv_cache_spec[layer_name] = FullAttentionSpec( + block_size=block_size, + num_kv_heads=attn_module.num_kv_heads, + head_size=attn_module.head_size, + dtype=attn_module.dtype, + ) + elif attn_module.attn_type in (AttentionType.ENCODER, + AttentionType.ENCODER_ONLY): + # encoder-only attention does not need KV cache. + continue + elif attn_module.attn_type == AttentionType.ENCODER_DECODER: + raise NotImplementedError + else: + raise ValueError( + f"Unknown attention type: {attn_module.attn_type}") + + return kv_cache_spec + + def _get_prompts_and_decodes( + self, + scheduler_output: "SchedulerOutput", + ) -> PromptDecodeInfo: + total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens + assert total_num_scheduled_tokens > 0 + num_reqs = self.input_batch.num_reqs + assert num_reqs > 0 + + # Traverse decodes first + decode_req_ids = [] + for i in range(num_reqs): + req_id = self.input_batch.req_ids[i] + assert req_id is not None + + num_computed_tokens = self.input_batch.num_computed_tokens_cpu[i] + num_prompt_tokens = self.input_batch.num_prompt_tokens[i] + num_scheduled_tokens = scheduler_output.num_scheduled_tokens[ + req_id] + + if num_computed_tokens < num_prompt_tokens: + # This is prompt + break + + # This is decode + assert num_scheduled_tokens == 1 + decode_req_ids.append(req_id) + + # Traverse prompts + prompt_req_ids = [] + prompt_scheduled_tokens = [] + for i in range(len(decode_req_ids), num_reqs): + req_id = self.input_batch.req_ids[i] + assert req_id is not None + + num_computed_tokens = self.input_batch.num_computed_tokens_cpu[i] + num_prompt_tokens = self.input_batch.num_prompt_tokens[i] + num_scheduled_tokens = scheduler_output.num_scheduled_tokens[ + req_id] + + # Must be prompt + assert num_computed_tokens < num_prompt_tokens + + prompt_req_ids.append(req_id) + prompt_scheduled_tokens.append(num_scheduled_tokens) + + return PromptDecodeInfo(prompt_req_ids, decode_req_ids, + prompt_scheduled_tokens) + + def _prepare_prompt(self, req_index: int, + num_scheduled_tokens: int) -> PromptData: + num_computed_tokens = self.input_batch.num_computed_tokens_cpu[ + req_index] + num_prompt_tokens = self.input_batch.num_prompt_tokens[req_index] + + # Must be prompt + assert num_computed_tokens < num_prompt_tokens + + # Prompt len + prompt_len = num_scheduled_tokens + padded_prompt_len = _get_padded_prompt_len(prompt_len) + assert padded_prompt_len <= self.max_model_len + + # Seq len + seq_len = num_computed_tokens + prompt_len + padded_seq_len = 
num_computed_tokens + padded_prompt_len + + # Input tokens + input_tokens_cpu = self.input_batch.token_ids_cpu_tensor[ + req_index, num_computed_tokens:padded_seq_len] + input_tokens_cpu[prompt_len:] = 0 + + # Input positions + input_positions_np = self.input_positions_np[ + self.cur_swap_id][:padded_prompt_len] + np.add(num_computed_tokens, + self.arange_np[:padded_prompt_len], + out=input_positions_np) + input_positions_np[prompt_len:] = 0 + + # Slot mapping + block_table_np = \ + self.input_batch.block_table.get_numpy_array() + block_numbers_np = block_table_np[req_index, input_positions_np // + self.block_size] + block_offsets_np = input_positions_np % self.block_size + + slot_mapping_np = self.slot_mapping_np[ + self.cur_swap_id][:padded_prompt_len] + np.add(block_numbers_np * self.block_size, + block_offsets_np, + out=slot_mapping_np) + slot_mapping_np[prompt_len:] = _PAD_SLOT_ID + + # Block table + block_table_cpu = None + if num_computed_tokens > 0: + block_table_cpu = self.input_batch.block_table.get_cpu_tensor() + block_table_cpu = block_table_cpu[req_index] + + # Context len + self.prompt_context_lens_cpu[self.cur_swap_id][0] = 0 + if num_computed_tokens > 0: + self.prompt_context_lens_cpu[self.cur_swap_id][0] = seq_len + + # Effective query len + self.prompt_effective_query_lens_cpu[self.cur_swap_id][0] = prompt_len + + # Get final tensors + input_tokens = input_tokens_cpu.reshape(1, -1).to(self.device) + input_positions = self.input_positions_cpu[ + self.cur_swap_id][:padded_prompt_len].reshape(1, + -1).to(self.device) + slot_mapping = self.slot_mapping_cpu[ + self.cur_swap_id][:padded_prompt_len].reshape(1, + -1).to(self.device) + block_table = block_table_cpu.reshape(1, -1).to( + self.device) if block_table_cpu is not None else None + + context_lens = self.prompt_context_lens_cpu[self.cur_swap_id].to( + self.device) + effective_query_lens = self.prompt_effective_query_lens_cpu[ + self.cur_swap_id].to(self.device) + + self.swap_step() + + # Attn metadata + attn_metadata = PallasMetadata( + num_prefills=1, + num_prefill_tokens=0, # NOTE: This is not used. + num_decode_tokens=0, + slot_mapping=slot_mapping, + multi_modal_placeholder_index_maps=None, + enable_kv_scales_calculation=True, + block_tables=block_table, + context_lens=context_lens, + effective_query_lens=effective_query_lens, + ) + + return PromptData(input_tokens, input_positions, attn_metadata) + + def _prepare_decode( + self, + decode_req_ids: List[str], + ) -> DecodeData: + # Batch size + batch_size = len(decode_req_ids) + padded_batch_size = _get_padded_batch_size(batch_size) + assert padded_batch_size <= self.max_model_len + + # Init [0 .. 
batch_size - 1] + req_indices_np = self.arange_np[:padded_batch_size] + + # Input positions + input_positions_np = self.input_positions_np[ + self.cur_swap_id][:padded_batch_size] + np.add(self.input_batch.num_computed_tokens_cpu[:padded_batch_size], + 0, + out=input_positions_np) + input_positions_np[batch_size:] = 0 + input_positions_cpu = self.input_positions_cpu[ + self.cur_swap_id][:padded_batch_size] + + # Input tokens + token_indices_np = ( + input_positions_np + + req_indices_np * self.input_batch.token_ids_cpu.shape[1]) + input_tokens_cpu = self.input_ids_cpu[ + self.cur_swap_id][:padded_batch_size] + torch.index_select(self.input_batch.token_ids_cpu_tensor.flatten(), + 0, + torch.from_numpy(token_indices_np), + out=input_tokens_cpu) + input_tokens_cpu[batch_size:] = 0 + + # Slot mapping + block_table_indices_np = ( + req_indices_np * self.max_num_blocks_per_req + + input_positions_np // self.block_size) + + block_table_cpu = self.input_batch.block_table.get_cpu_tensor() + + block_numbers_np = block_table_cpu.flatten( + )[block_table_indices_np].numpy() + + block_offsets_np = input_positions_np % self.block_size + + slot_mapping_np = self.slot_mapping_np[ + self.cur_swap_id][:padded_batch_size] + np.add(block_numbers_np * self.block_size, + block_offsets_np, + out=slot_mapping_np) + slot_mapping_np[batch_size:] = _PAD_SLOT_ID + + block_table_cpu = block_table_cpu[:padded_batch_size] + + # Context lens + context_lens_np = self.decode_context_lens_np[ + self.cur_swap_id][:padded_batch_size] + np.add(self.input_batch.num_computed_tokens_cpu[:padded_batch_size], + 1, + out=context_lens_np) + context_lens_np[batch_size:] = 0 + + # Get final tensors + input_tokens = input_tokens_cpu.reshape(-1, 1).to(self.device) + input_positions = input_positions_cpu.reshape(-1, 1).to(self.device) + slot_mapping = self.slot_mapping_cpu[ + self.cur_swap_id][:padded_batch_size].reshape(-1, + 1).to(self.device) + block_table = block_table_cpu.to(self.device) + context_lens = self.decode_context_lens_cpu[ + self.cur_swap_id][:padded_batch_size].to(self.device) + + self.swap_step() + + # Attn metadata + attn_metadata = PallasMetadata( + num_prefills=0, + num_prefill_tokens=0, + num_decode_tokens=padded_batch_size, + slot_mapping=slot_mapping, + multi_modal_placeholder_index_maps=None, + enable_kv_scales_calculation=True, + block_tables=block_table, + context_lens=context_lens, + effective_query_lens=None, + ) + + return DecodeData(input_tokens=input_tokens, + input_positions=input_positions, + attn_metadata=attn_metadata) + + @torch.no_grad() + def execute_model( + self, + scheduler_output: "SchedulerOutput", + ) -> ModelRunnerOutput: + # Update cached state + self._update_states(scheduler_output) + + # If necessary, swap decodes/prompts to have all decodes on the start + ensure_decodes_first(self.input_batch) + + # Prepare prompts/decodes info + pd_info = self._get_prompts_and_decodes(scheduler_output) + + # Init + num_prompts = len(pd_info.prompt_req_ids) + num_decodes = len(pd_info.decode_req_ids) + decode_data = None + sampled_token_ids = [0] * self.input_batch.num_reqs + + # Run each prompt individually + is_first = True + for i in range(num_prompts): + req_id = pd_info.prompt_req_ids[i] + req_index = num_decodes + i + assert req_index == self.input_batch.req_id_to_index[ + req_id] # TODO: Remove + req_state = self.requests[req_id] + num_scheduled_tokens = pd_info.prompt_scheduled_tokens[i] + prompt_len = num_scheduled_tokens + seq_len = req_state.num_computed_tokens + num_scheduled_tokens + + # 
Prepare first prompt + if is_first: + prompt_data = self._prepare_prompt(req_index, + num_scheduled_tokens) + is_first = False + + # Run forward pass + with set_forward_context(prompt_data.attn_metadata, + self.vllm_config): + assert self.model is not None + selected_token_ids = self.model(prompt_data.input_tokens, + prompt_data.input_positions, + prompt_data.attn_metadata, + self.kv_caches) + + # In parallel to TPU execution, prepare the next iteration + if i < num_prompts - 1: + # There is next prompt => prepare it + prompt_data = self._prepare_prompt( + req_index + 1, pd_info.prompt_scheduled_tokens[i + 1]) + elif i == num_prompts - 1 and num_decodes > 0: + # There is next decode => prepare it + decode_data = self._prepare_decode(pd_info.decode_req_ids) + + # Update cached state (if prompt is fully done) + if seq_len >= len(req_state.prompt_token_ids): + # Transfer sampled tokens from TPU to CPU + selected_token_ids_cpu = selected_token_ids.cpu() + + # Get output token + token_id = selected_token_ids_cpu[prompt_len - 1].item() + sampled_token_ids[req_index] = token_id + + # Add output token to the request + self.input_batch.token_ids_cpu[req_index, seq_len] = token_id + self.input_batch.num_tokens[req_index] += 1 + req_state.output_token_ids.append(token_id) + + # Run decodes (a single batch) + if num_decodes > 0: + + # Prepare decode (if was not yet prepared) + if decode_data is None: + decode_data = self._prepare_decode(pd_info.decode_req_ids) + + # Run forward pass + with set_forward_context(decode_data.attn_metadata, + self.vllm_config): + assert self.model is not None + selected_token_ids = self.model(decode_data.input_tokens, + decode_data.input_positions, + decode_data.attn_metadata, + self.kv_caches) + + # Transfer sampled tokens from TPU to CPU + decode_token_ids_cpu = selected_token_ids.cpu() + # Convert to list + decode_token_ids_list = decode_token_ids_cpu.tolist() + + # Update cached state for each decode request + for i in range(num_decodes): + req_id = pd_info.decode_req_ids[i] + req_index = i + assert req_index == self.input_batch.req_id_to_index[ + req_id] # TODO: Remove + req_state = self.requests[req_id] + seq_len = req_state.num_computed_tokens + 1 + + token_id = decode_token_ids_list[i] + sampled_token_ids[req_index] = token_id + + self.input_batch.token_ids_cpu[req_index, seq_len] = token_id + self.input_batch.num_tokens[req_index] += 1 + req_state.output_token_ids.append(token_id) + + # Create output. + all_req_ids = pd_info.decode_req_ids + pd_info.prompt_req_ids + prompt_logprobs_dict: Dict[str, Optional[LogprobsTensors]] = {} + for req_id in all_req_ids: + prompt_logprobs_dict[req_id] = None + + model_runner_output = ModelRunnerOutput( + req_ids=all_req_ids, + req_id_to_index=self.input_batch.req_id_to_index, + sampled_token_ids=sampled_token_ids, + logprobs=None, + prompt_logprobs_dict=prompt_logprobs_dict, # type: ignore[arg-type] + ) + + return model_runner_output + + def load_model(self) -> None: + self.device = self.device_config.device + + # NOTE(woosuk): While the executor assigns the TP ranks to the worker + # process, the ranks can be different from the ranks internally assigned + # by the xm runtime. Therefore, there is a mismatch in the rank + # assignment between the gloo (cpu) runtime and the xm (tpu) runtime. + # This is not a problem in linear layers because all-reduce is + # rank-agnostic. However, it matters for all-gather as the ranks + # determine the order of concatenating the output tensors. 
+ # As a workaround, we use the xm's rank assignment only when loading + # the embedding weights. + xm_tp_rank = xr.global_ordinal() + with patch( + "vllm.model_executor.layers.vocab_parallel_embedding." + "get_tensor_model_parallel_rank", + return_value=xm_tp_rank): + model = get_model(vllm_config=self.vllm_config) + model = model.eval() + xm.mark_step() + xm.wait_device_ops() + model = ModelWrapperV1(model) + self.model = torch.compile(model, + backend="openxla", + fullgraph=True, + dynamic=False) + + def dummy_run( + self, + kv_caches, + num_tokens: int, + seq_len: Optional[int] = None, + exec_mode: Optional[ExecutionMode] = None, + ) -> None: + assert seq_len is not None + assert exec_mode is not None + + exec_mode = ExecutionMode(exec_mode) + if exec_mode.is_prefill(): + seq_len = (seq_len + 15) // 16 * 16 + token_ids = torch.zeros((num_tokens, seq_len), + dtype=torch.int32, + device=self.device) + position_ids = torch.zeros((num_tokens, seq_len), + dtype=torch.int32, + device=self.device) + slot_mapping = torch.zeros((num_tokens, seq_len), + dtype=torch.int64, + device=self.device) + if exec_mode == ExecutionMode.PREFILL: + attn_metadata = PallasMetadata( + num_prefills=num_tokens, + num_prefill_tokens=num_tokens * seq_len, + num_decode_tokens=0, + slot_mapping=slot_mapping, + multi_modal_placeholder_index_maps=None, + enable_kv_scales_calculation=True, + block_tables=None, + context_lens=None, + effective_query_lens=None, + ) + + else: + context_lens = torch.ones((num_tokens, ), + dtype=torch.int32, + device=self.device) + + block_tables = torch.zeros( + (num_tokens, self.max_num_blocks_per_req), + dtype=torch.int32, + device=self.device) + + effective_query_lens = torch.ones_like(context_lens) + + attn_metadata = PallasMetadata( + num_prefills=num_tokens, + num_prefill_tokens=num_tokens * seq_len, + num_decode_tokens=0, + slot_mapping=slot_mapping, + multi_modal_placeholder_index_maps=None, + enable_kv_scales_calculation=True, + block_tables=block_tables, + context_lens=context_lens, + effective_query_lens=effective_query_lens, + ) + else: + assert seq_len == 1 + token_ids = torch.zeros((num_tokens, seq_len), + dtype=torch.int32, + device=self.device) + position_ids = torch.zeros((num_tokens, seq_len), + dtype=torch.int32, + device=self.device) + slot_mapping = torch.zeros((num_tokens, seq_len), + dtype=torch.int64, + device=self.device) + block_tables = torch.zeros( + (num_tokens, self.max_num_blocks_per_req), + dtype=torch.int32, + device=self.device) + context_lens = torch.ones((num_tokens, ), + dtype=torch.int32, + device=self.device) + attn_metadata = PallasMetadata( + num_prefills=0, + num_prefill_tokens=0, + num_decode_tokens=num_tokens * seq_len, + slot_mapping=slot_mapping, + multi_modal_placeholder_index_maps=None, + enable_kv_scales_calculation=True, + block_tables=block_tables, + context_lens=context_lens, + ) + + # NOTE(woosuk): There are two stages of compilation: torch.compile and + # XLA compilation. Using `mark_dynamic` can reduce the torch.compile + # overhead by reusing the FX graph for different shapes. + # However, the XLA graph will still require static shapes and needs to + # be re-compiled for every different shapes. This overhead is inevitable + # in the first run, but can be skipped afterwards as we cache the XLA + # graphs in the disk (VLLM_XLA_CACHE_PATH). 
+ if exec_mode.is_prefill(): + # Prefll + torch._dynamo.mark_dynamic(token_ids, 1) + torch._dynamo.mark_dynamic(position_ids, 1) + torch._dynamo.mark_dynamic(attn_metadata.slot_mapping, 1) + else: + # Decode + torch._dynamo.mark_dynamic(token_ids, 0) + torch._dynamo.mark_dynamic(position_ids, 0) + torch._dynamo.mark_dynamic(attn_metadata.slot_mapping, 0) + torch._dynamo.mark_dynamic(attn_metadata.context_lens, 0) + torch._dynamo.mark_dynamic(attn_metadata.block_tables, 0) + + with set_forward_context(attn_metadata, self.vllm_config, 0): + assert self.model is not None + self.model(token_ids, position_ids, attn_metadata, kv_caches) + + def capture_model(self) -> None: + """Compile the model.""" + + # Prefill + logger.info( + "Compiling the model with different input shapes for prefill:") + start = time.time() + for batch_size in [1]: + seq_len = 16 + while seq_len <= self.model_config.max_model_len: + self.dummy_run(self.kv_caches, + batch_size, + seq_len, + exec_mode=ExecutionMode.PREFILL) + xm.wait_device_ops() + logger.info(" batch_size: %d, seq_len: %d", batch_size, + seq_len) + num_tokens = batch_size * seq_len + if num_tokens >= self.scheduler_config.max_num_batched_tokens: + break + seq_len = seq_len * 2 + + end = time.time() + logger.info(" -- Compilation for prefill done in %.2f [secs].", + end - start) + + # Prefix prefill + if self.scheduler_config.enable_chunked_prefill: + logger.info("Compiling the model with different input shapes for " + "prefix prefill:") + start = time.time() + for batch_size in [1]: + seq_len = 16 + while seq_len <= self.model_config.max_model_len: + self.dummy_run(self.kv_caches, + batch_size, + seq_len, + exec_mode=ExecutionMode.PREFIX_PREFILL) + xm.wait_device_ops() + logger.info(" batch_size: %d, seq_len: %d", batch_size, + seq_len) + num_tokens = batch_size * seq_len + if (num_tokens + >= self.scheduler_config.max_num_batched_tokens): + break + seq_len = seq_len * 2 + end = time.time() + logger.info( + " -- Compilation for prefix prefill done in %.2f [secs].", + end - start) + + # Decode + logger.info( + "Compiling the model with different input shapes for decode:") + start = time.time() + seq_len = 1 + batch_size = 8 # Must be in sync with _get_padded_batch_size() + while True: + self.dummy_run(self.kv_caches, + batch_size, + seq_len, + exec_mode=ExecutionMode.DECODE) + xm.wait_device_ops() + logger.info(" batch_size: %d, seq_len: %d", batch_size, seq_len) + + if batch_size >= self.scheduler_config.max_num_seqs: + break + batch_size = batch_size + 16 if batch_size >= 16 else batch_size * 2 + + end = time.time() + logger.info(" -- Compilation for decode done in %.2f [secs].", + end - start) + + def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None: + """ + Initialize KV cache based on `kv_cache_config`. 
+ Args: + kv_cache_config: Configuration for the KV cache, including the KV + cache size of each layer + """ + if len(kv_cache_config.groups) > 1: + raise NotImplementedError( + "Hybrid models with more than one KV cache type are not " + "supported yet.") + + kv_caches: Dict[str, torch.Tensor] = {} + + for layer_name, layer_spec in kv_cache_config.kv_cache_spec.items(): + tensor_config = kv_cache_config.tensors[layer_name] + assert tensor_config.size % layer_spec.page_size_bytes == 0 + num_blocks = tensor_config.size // layer_spec.page_size_bytes + if isinstance(layer_spec, FullAttentionSpec): + kv_cache_shape = PallasAttentionBackend.get_kv_cache_shape( + num_blocks, layer_spec.block_size, layer_spec.num_kv_heads, + layer_spec.head_size) + dtype = layer_spec.dtype + + tpu_k_cache = torch.zeros(kv_cache_shape, + dtype=dtype, + device=self.device) + tpu_v_cache = torch.zeros_like(tpu_k_cache) + + kv_caches[layer_name] = (tpu_k_cache, tpu_v_cache) + else: + raise NotImplementedError + + bind_kv_cache( + kv_caches, + self.vllm_config.compilation_config.static_forward_context, + self.kv_caches) + + +class ModelWrapperV1(nn.Module): + + def __init__(self, model: nn.Module): + super().__init__() + self.model = model + + def forward( + self, + token_ids: torch.Tensor, + position_ids: torch.Tensor, + attn_metadata: AttentionMetadata, + kv_caches: List[Tuple[torch.Tensor, torch.Tensor]], + ) -> torch.Tensor: + """Executes the forward pass of the model and samples the next token. + + Args: + token_ids: The input token IDs of shape [batch_size, seq_len]. + position_ids: The input position IDs of shape [batch_size, seq_len]. + attn_metadata: The Pallas attention metadata. + input_lens: The actual input lengths of shape [batch_size]. + t: The sampling temperature of shape [batch_size]. + p: The top-p probability of shape [batch_size]. + num_samples: Number of samples to draw from each logits vector. + kv_caches: The key and value caches. They can be None during the + memory profiling at initialization. + """ + # Skip this in memory profiling at initialization. + if attn_metadata is not None and kv_caches[0][0].numel() > 0: + # index_copy_(slot_mapping) only works when the inserted dimension + # is 0. However, the KV cache in the Pallas backend has the shape + # [num_kv_heads, num_blocks, block_size, head_size]. To make it + # work, we need to flatten the first three dimensions and modify + # the slot_mapping accordingly. + num_kv_heads, num_blocks, block_size, _ = kv_caches[0][0].shape + slot_mapping = attn_metadata.slot_mapping + slot_mapping = slot_mapping.flatten() + head_indicies = torch.arange(0, + num_kv_heads, + device=slot_mapping.device, + dtype=slot_mapping.dtype) + head_indicies *= block_size * num_blocks + slot_mapping = slot_mapping.repeat_interleave(num_kv_heads).view( + -1, num_kv_heads) + slot_mapping = slot_mapping + head_indicies.view(1, -1) + slot_mapping = slot_mapping.flatten() + attn_metadata.slot_mapping = slot_mapping + + assert self.model is not None + hidden_states = self.model( + token_ids, + position_ids, + kv_caches, + attn_metadata, + ) + + hidden_states = hidden_states.flatten(0, 1) + logits = self.model.compute_logits(hidden_states, None) + + # Greedy sampling. 
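+        # NOTE: only greedy (argmax) sampling is implemented in this wrapper;
+        # temperature/top-p/top-k from SamplingParams are not applied yet.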
+ argmax_token_ids = torch.argmax(logits, dim=-1, keepdim=True) + argmax_token_ids = argmax_token_ids.squeeze(dim=-1) + return argmax_token_ids + + +def swap_positions(b: InputBatch, id_1, id_2): + assert id_1 != id_2 + req_id_1 = b.req_ids[id_1] + req_id_2 = b.req_ids[id_2] + assert req_id_1 is not None + assert req_id_2 is not None + assert id_1 == b.req_id_to_index[req_id_1] + assert id_2 == b.req_id_to_index[req_id_2] + + b.req_ids[id_1], b.req_ids[id_2] = b.req_ids[id_2], b.req_ids[id_1] + b.req_id_to_index[req_id_1], b.req_id_to_index[ + req_id_2] = b.req_id_to_index[req_id_2], b.req_id_to_index[req_id_1] + + ids = [id_1, id_2] + rev_ids = [id_2, id_1] + b.num_tokens[ids] = b.num_tokens[rev_ids] + b.token_ids_cpu[ids] = b.token_ids_cpu[rev_ids] + b.num_prompt_tokens[ids] = b.num_prompt_tokens[rev_ids] + b.num_computed_tokens_cpu[ids] = b.num_computed_tokens_cpu[rev_ids] + + b.block_table.swap_row(id_1, id_2) + + b.temperature_cpu[ids] = b.temperature_cpu[rev_ids] + b.top_p_cpu[ids] = b.top_p_cpu[rev_ids] + b.top_k_cpu[ids] = b.top_k_cpu[rev_ids] + b.frequency_penalties_cpu[ids] = b.frequency_penalties_cpu[rev_ids] + b.presence_penalties_cpu[ids] = b.presence_penalties_cpu[rev_ids] + b.repetition_penalties_cpu[ids] = b.repetition_penalties_cpu[rev_ids] + + b.min_tokens[id_1], b.min_tokens[id_2] = b.min_tokens[id_2], b.min_tokens[ + id_1] + b.stop_token_ids[id_1], b.stop_token_ids[id_2] = b.stop_token_ids[ + id_2], b.stop_token_ids[id_1] + + gen_1 = b.generators.pop(id_1, None) + gen_2 = b.generators.pop(id_2, None) + if gen_1 is not None: + b.generators[id_2] = gen_1 + if gen_2 is not None: + b.generators[id_1] = gen_2 + + +def ensure_decodes_first(b: InputBatch): + num_reqs = b.num_reqs + while True: + # Find the first prompt index + first_prompt_index = None + for i in range(num_reqs): + if b.num_computed_tokens_cpu[i] < b.num_prompt_tokens[i]: + first_prompt_index = i + break + if first_prompt_index is None: + break + + # Find the last decode index + last_decode_index = None + for i in reversed(range(num_reqs)): + if b.num_computed_tokens_cpu[i] >= b.num_prompt_tokens[i]: + last_decode_index = i + break + if last_decode_index is None: + break + + # Sanity + assert first_prompt_index != last_decode_index + + # Check if done + if first_prompt_index > last_decode_index: + break + + # Swap + swap_positions(b, first_prompt_index, last_decode_index) + + +def _get_padded_prompt_len(x: int) -> int: + # NOTE(woosuk): The pallas FlashAttention kernel requires the sequence + # length to be a multiple of 16. We pad the prompt length to the nearest + # multiple of 16. This is also good for performance. + if x <= 16: + return 16 + return 1 << (x - 1).bit_length() + + +def _get_padded_batch_size(batch_size: int) -> int: + # The GMM Pallas kernel requires num_tokens * topk to be a multiple of 16. + # To meet this requirement in the simplest way, we set the minimal batch + # size to 8. 
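+    # In effect: batch sizes up to 8 are padded to 8, and larger batches are
+    # rounded up to the next multiple of 16 (16, 32, 48, ...), matching the
+    # decode batch sizes precompiled in capture_model().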
+ if batch_size <= 8: + return 8 + else: + return ((batch_size + 15) // 16) * 16 diff --git a/vllm/v1/worker/tpu_worker.py b/vllm/v1/worker/tpu_worker.py new file mode 100644 index 000000000000..f29edd34ede3 --- /dev/null +++ b/vllm/v1/worker/tpu_worker.py @@ -0,0 +1,203 @@ +# SPDX-License-Identifier: Apache-2.0 +"""A TPU worker class.""" +import os +from typing import Dict, List, Optional + +import torch +import torch.distributed +import torch.nn as nn +import torch_xla.core.xla_model as xm +import torch_xla.runtime as xr + +import vllm.envs as envs +from vllm.config import ParallelConfig, VllmConfig +from vllm.distributed import (ensure_model_parallel_initialized, + init_distributed_environment) +from vllm.logger import init_logger +from vllm.model_executor import set_random_seed +from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE +from vllm.v1.core.scheduler import SchedulerOutput +from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig, + KVCacheSpec) +from vllm.v1.outputs import ModelRunnerOutput +from vllm.v1.utils import bind_kv_cache +from vllm.v1.worker.tpu_model_runner import ExecutionMode, TPUModelRunner + +logger = init_logger(__name__) + + +class TPUWorker: + + def __init__( + self, + vllm_config: VllmConfig, + local_rank: int, + rank: int, + distributed_init_method: str, + is_driver_worker: bool = False, + ): + self.vllm_config = vllm_config + self.model_config = vllm_config.model_config + self.cache_config = vllm_config.cache_config + self.lora_config = vllm_config.lora_config + self.load_config = vllm_config.load_config + self.parallel_config = vllm_config.parallel_config + self.scheduler_config = vllm_config.scheduler_config + self.device_config = vllm_config.device_config + self.speculative_config = vllm_config.speculative_config + self.prompt_adapter_config = vllm_config.prompt_adapter_config + self.observability_config = vllm_config.observability_config + + self.parallel_config.rank = rank + self.local_rank = local_rank + self.rank = rank + self.distributed_init_method = distributed_init_method + + if self.cache_config.cache_dtype == "auto": + self.cache_dtype = self.model_config.dtype + else: + self.cache_dtype = STR_DTYPE_TO_TORCH_DTYPE[ + self.cache_config.cache_dtype] + + if self.model_config.trust_remote_code: + # note: lazy import to avoid importing torch before initializing + from vllm.utils import init_cached_hf_modules + init_cached_hf_modules() + + def init_device(self): + os.environ["PJRT_DEVICE"] = "TPU" + torch.set_grad_enabled(False) + torch.set_default_dtype(self.model_config.dtype) + + # Initialize the distributed environment. + init_tpu_worker_distributed_environment(self.parallel_config, + self.rank, + self.distributed_init_method, + self.local_rank) + + # Device initialization should happen after initializing + # the distributed runtime. + self.device = xm.xla_device() + self.device_config.device = self.device + + # Set random seed. + set_random_seed(self.model_config.seed) + xm.set_rng_state(self.model_config.seed, self.device) + + # Increase the cache size limit, which is the maximum number of + # dynamo graphs that can be compiled. + # NOTE(woosuk): Usually, we compile 10-15 graphs for prefill and + # 30-40 graphs for decode. 128 is an arbitrary safe number. + torch._dynamo.config.cache_size_limit = 128 + # Use persistent cache to avoid XLA recompilation. + # NOTE(woosuk): Set per-rank cache path since different ranks + # can have slightly different XLA graphs. 
+ world_size = self.parallel_config.world_size + rank = xr.global_ordinal() + per_rank_path = os.path.join(envs.VLLM_XLA_CACHE_PATH, + f"tp{world_size}_rank{rank}") + xr.initialize_cache(per_rank_path, readonly=False) + + # Init ModelRunner here, so that we have access to self.device. + self.model_runner = TPUModelRunner(self.vllm_config, self.device) + + def determine_available_memory(self) -> int: + kv_caches: Dict[str, torch.Tensor] = {} + kv_cache_spec = self.model_runner.get_kv_cache_spec() + for layer_name, layer_spec in kv_cache_spec.items(): + if isinstance(layer_spec, FullAttentionSpec): + dtype = layer_spec.dtype + + # Use an empty tensor instead of `None`` to force Dynamo to pass + # it by reference, rather by specializing on the value ``None``. + tpu_k_cache = torch.tensor([], dtype=dtype, device=self.device) + tpu_v_cache = torch.tensor([], dtype=dtype, device=self.device) + + kv_caches[layer_name] = (tpu_k_cache, tpu_v_cache) + else: + raise NotImplementedError + + runner_kv_caches: List[torch.Tensor] = [] + bind_kv_cache( + kv_caches, + self.vllm_config.compilation_config.static_forward_context, + runner_kv_caches) + + self.model_runner.dummy_run( + runner_kv_caches, + num_tokens=1, + seq_len=self.scheduler_config.max_num_batched_tokens, + exec_mode=ExecutionMode.PREFILL, + ) + + # Synchronize before measuring the memory usage. + xm.wait_device_ops() + + # Get the maximum amount of memory used by the model weights and + # intermediate activations. + m = xm.get_memory_info(self.device) + total_memory_size = m["bytes_limit"] + profiled = m["peak_bytes_used"] # Weights + intermediate activations. + + # Calculate the TPU KV cache size based on profiling. + usable_memory_size = int(total_memory_size * + self.cache_config.gpu_memory_utilization) + tpu_kv_cache_bytes = max(usable_memory_size - profiled, 0) + + return int(tpu_kv_cache_bytes) + + def execute_model( + self, + scheduler_output: "SchedulerOutput", + ) -> Optional[ModelRunnerOutput]: + output = self.model_runner.execute_model(scheduler_output) + return output if self.rank == 0 else None + + def load_model(self) -> None: + self.model_runner.load_model() + + def compile_or_warm_up_model(self) -> None: + if not self.model_config.enforce_eager: + self.model_runner.capture_model() + + # Reset the seed to ensure that the random state is not affected by + # the model initialization and profiling. + set_random_seed(self.model_config.seed) + + def get_model(self) -> nn.Module: + return self.model_runner.get_model() + + def get_kv_cache_spec(self) -> KVCacheSpec: + return self.model_runner.get_kv_cache_spec() + + def initialize_cache(self, kv_cache_configs: List[KVCacheConfig]) -> None: + """Allocate GPU KV cache with the specified kv_cache_config.""" + kv_cache_config = kv_cache_configs[self.rank] + self.model_runner.initialize_kv_cache(kv_cache_config) + + def check_health(self) -> None: + # worker will always be healthy as long as it's running. + return + + +def init_tpu_worker_distributed_environment( + parallel_config: ParallelConfig, + rank: int, + distributed_init_method: Optional[str] = None, + local_rank: int = -1, +) -> None: + """Initialize the distributed environment.""" + + # NOTE(woosuk): This is just to initialize the TP group and broadcast + # the input objects on CPU. The all-reduce and all-gather ops on TPU + # are invoked by `xm.all_reduce` and `xm.all_gather` which use their + # own context. 
+ init_distributed_environment( + world_size=parallel_config.world_size, + rank=rank, + local_rank=local_rank, + distributed_init_method=distributed_init_method, + backend="gloo", + ) + ensure_model_parallel_initialized(parallel_config.tensor_parallel_size, + parallel_config.pipeline_parallel_size)
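A minimal usage sketch (not part of the patch) of the V1 TPU path added here, mirroring the updated accuracy tests: the V1 engine is selected via the VLLM_USE_V1 environment variable, and max_num_seqs is capped at 64 to keep XLA compilation time bounded. The model name and prompt are placeholders.

# Assumes a TPU build of vLLM with this patch applied.
import os
os.environ["VLLM_USE_V1"] = "1"  # select the V1 engine

from vllm import LLM, SamplingParams

llm = LLM(
    model="Qwen/Qwen2-1.5B-Instruct",  # placeholder model
    max_model_len=4096,
    max_num_seqs=64,  # limits TPU compilation time, as in the updated tests
)
outputs = llm.generate(
    ["The capital of France is"],
    SamplingParams(temperature=0.0, max_tokens=16),  # V1 TPU currently samples greedily
)
print(outputs[0].outputs[0].text)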