From c715fb19e5da39025823f383cd22e0c13083e494 Mon Sep 17 00:00:00 2001 From: Alexander Matveev Date: Thu, 9 Jan 2025 17:00:34 +0000 Subject: [PATCH] [V1] TPU support Signed-off-by: Alexander Matveev --- .pre-commit-config.yaml | 2 +- examples/offline_inference/basic.py | 6 +- tests/entrypoints/openai/test_accuracy.py | 13 +- tools/mypy.sh | 2 +- vllm/platforms/cuda.py | 2 +- vllm/platforms/interface.py | 1 + vllm/platforms/tpu.py | 59 +- vllm/v1/attention/backends/pallas.py | 351 ++++++++++ vllm/v1/core/scheduler.py | 7 + vllm/v1/worker/gpu_input_batch.py | 2 +- vllm/v1/worker/gpu_model_runner.py | 144 +--- vllm/v1/worker/gpu_worker.py | 170 ++--- vllm/v1/worker/model_runner_base.py | 142 ++++ vllm/v1/worker/tpu_model_runner.py | 778 ++++++++++++++++++++++ vllm/v1/worker/tpu_worker.py | 141 ++++ vllm/v1/worker/worker_base.py | 173 +++++ 16 files changed, 1752 insertions(+), 241 deletions(-) create mode 100644 vllm/v1/attention/backends/pallas.py create mode 100644 vllm/v1/worker/model_runner_base.py create mode 100644 vllm/v1/worker/tpu_model_runner.py create mode 100644 vllm/v1/worker/tpu_worker.py create mode 100644 vllm/v1/worker/worker_base.py diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 432bf5ed18dbc..a90ddcddf5db2 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -89,4 +89,4 @@ repos: name: Suggestion entry: bash -c 'echo "To bypass pre-commit hooks, add --no-verify to git commit."' language: system - verbose: true + verbose: true \ No newline at end of file diff --git a/examples/offline_inference/basic.py b/examples/offline_inference/basic.py index 23cc6e8539431..d80a354259db5 100644 --- a/examples/offline_inference/basic.py +++ b/examples/offline_inference/basic.py @@ -8,10 +8,10 @@ prompts = [ "The future of AI is", ] # Create a sampling params object. -sampling_params = SamplingParams(temperature=0.8, top_p=0.95) +sampling_params = SamplingParams() #temperature=0.8, top_p=0.95) # Create an LLM. -llm = LLM(model="facebook/opt-125m") +llm = LLM(model="Qwen/Qwen2-1.5B-Instruct", max_model_len=512, max_num_seqs=16) # Generate texts from the prompts. The output is a list of RequestOutput objects # that contain the prompt, generated text, and other information. outputs = llm.generate(prompts, sampling_params) @@ -19,4 +19,4 @@ outputs = llm.generate(prompts, sampling_params) for output in outputs: prompt = output.prompt generated_text = output.outputs[0].text - print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") \ No newline at end of file + print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") diff --git a/tests/entrypoints/openai/test_accuracy.py b/tests/entrypoints/openai/test_accuracy.py index b1d4461d164aa..976c8f3473d1a 100644 --- a/tests/entrypoints/openai/test_accuracy.py +++ b/tests/entrypoints/openai/test_accuracy.py @@ -66,14 +66,21 @@ def run_test(more_args): ), f"Expected: {EXPECTED_VALUE} | Measured: {measured_value}" -@pytest.mark.skipif(not current_platform.is_cuda(), - reason="V1 currently only supported on CUDA") +@pytest.mark.skipif(not current_platform.is_cuda() + and not current_platform.is_tpu(), + reason="V1 currently only supported on CUDA and TPU") def test_lm_eval_accuracy_v1_engine(monkeypatch): """Run with the V1 Engine.""" with monkeypatch.context() as m: m.setenv("VLLM_USE_V1", "1") - run_test([]) + more_args = [] + + # Limit compilation time for V1 + if current_platform.is_tpu(): + more_args = ["--max-num-seqs", "64"] + + run_test(more_args) @pytest.mark.parametrize("more_args", MORE_ARGS_LIST) diff --git a/tools/mypy.sh b/tools/mypy.sh index 77d342da1ec82..d477a7ecfbe82 100755 --- a/tools/mypy.sh +++ b/tools/mypy.sh @@ -34,4 +34,4 @@ run_mypy vllm/plugins run_mypy vllm/prompt_adapter run_mypy vllm/spec_decode run_mypy vllm/worker -run_mypy vllm/v1 +run_mypy vllm/v1 \ No newline at end of file diff --git a/vllm/platforms/cuda.py b/vllm/platforms/cuda.py index 2587e3a11dde3..5ca3aef441062 100644 --- a/vllm/platforms/cuda.py +++ b/vllm/platforms/cuda.py @@ -135,7 +135,7 @@ class CudaPlatformBase(Platform): else: if envs.VLLM_USE_V1: parallel_config.worker_cls = \ - "vllm.v1.worker.gpu_worker.Worker" + "vllm.v1.worker.gpu_worker.GPUWorker" else: parallel_config.worker_cls = "vllm.worker.worker.Worker" diff --git a/vllm/platforms/interface.py b/vllm/platforms/interface.py index f2ecec3203fb7..7607beb766e33 100644 --- a/vllm/platforms/interface.py +++ b/vllm/platforms/interface.py @@ -32,6 +32,7 @@ class _Backend(enum.Enum): FLASHINFER = enum.auto() HPU_ATTN = enum.auto() PALLAS = enum.auto() + PALLAS_VLLM_V1 = enum.auto() IPEX = enum.auto() BLOCK_SPARSE_FLASH_ATTN = enum.auto() NO_ATTENTION = enum.auto() diff --git a/vllm/platforms/tpu.py b/vllm/platforms/tpu.py index 05a3aa4305cfa..65bfff7311cd4 100644 --- a/vllm/platforms/tpu.py +++ b/vllm/platforms/tpu.py @@ -2,6 +2,7 @@ from typing import TYPE_CHECKING, Optional import torch +import vllm.envs as envs from vllm.logger import init_logger from .interface import Platform, PlatformEnum, _Backend @@ -30,10 +31,16 @@ class TpuPlatform(Platform): def get_attn_backend_cls(cls, selected_backend: _Backend, head_size: int, dtype: torch.dtype, kv_cache_dtype: Optional[str], block_size: int, use_v1: bool) -> str: - if selected_backend != _Backend.PALLAS: + if (selected_backend != _Backend.PALLAS + and selected_backend != _Backend.PALLAS_VLLM_V1): logger.info("Cannot use %s backend on TPU.", selected_backend) - logger.info("Using Pallas backend.") - return "vllm.attention.backends.pallas.PallasAttentionBackend" + + if use_v1: + logger.info("Using Pallas V1 backend.") + return "vllm.v1.attention.backends.pallas.PallasAttentionBackend" + else: + logger.info("Using Pallas backend.") + return "vllm.attention.backends.pallas.PallasAttentionBackend" @classmethod def get_device_name(cls, device_id: int = 0) -> str: @@ -45,7 +52,7 @@ class TpuPlatform(Platform): @classmethod def is_async_output_supported(cls, enforce_eager: Optional[bool]) -> bool: - return True + return not envs.VLLM_USE_V1 @classmethod def inference_mode(cls): @@ -60,11 +67,11 @@ class TpuPlatform(Platform): cache_config.block_size = 16 compilation_config = vllm_config.compilation_config - if compilation_config.level == CompilationLevel.NO_COMPILATION: - # TPU does not support NO_COMPILATION + + # TPU only supports DYNAMO_ONCE compilation level + if compilation_config.level != CompilationLevel.DYNAMO_ONCE: + logger.info("[TPU] Forcing DYNAMO_ONCE compilation level") compilation_config.level = CompilationLevel.DYNAMO_ONCE - assert compilation_config.level < CompilationLevel.PIECEWISE,\ - "TPU does not support Inductor." if compilation_config.backend == "": compilation_config.backend = "openxla" @@ -72,10 +79,6 @@ class TpuPlatform(Platform): assert vllm_config.speculative_config is None, \ "TPU does not support speculative decoding" - assert not vllm_config.scheduler_config.chunked_prefill_enabled, ( - "Chunked prefill is not yet supported for TPU backend") - assert not vllm_config.speculative_config, ( - "Speculative decoding is not yet supported for TPU backend") if vllm_config.model_config.dtype in (torch.float16, torch.float32): logger.warning( "The TPU backend currently does not support %s. " @@ -85,8 +88,34 @@ class TpuPlatform(Platform): parallel_config = vllm_config.parallel_config scheduler_config = vllm_config.scheduler_config if parallel_config.worker_cls == "auto": - if scheduler_config.is_multi_step: + if envs.VLLM_USE_V1: parallel_config.worker_cls = \ - "vllm.worker.multi_step_tpu_worker.MultiStepTPUWorker" + "vllm.v1.worker.tpu_worker.TPUWorker" else: - parallel_config.worker_cls = "vllm.worker.tpu_worker.TPUWorker" + if scheduler_config.is_multi_step: + parallel_config.worker_cls = \ + "vllm.worker.multi_step_tpu_worker.MultiStepTPUWorker" + else: + parallel_config.worker_cls = \ + "vllm.worker.tpu_worker.TPUWorker" + + # Adjust scheduler config for V1 + # TODO: Add support for these + if envs.VLLM_USE_V1: + if vllm_config.cache_config.enable_prefix_caching: + logger.warning("[V1][TPU] Disable prefix caching") + vllm_config.cache_config.enable_prefix_caching = False + + if vllm_config.scheduler_config.chunked_prefill_enabled: + logger.warning("[V1][TPU] Disable chunked prefill") + vllm_config.scheduler_config.chunked_prefill_enabled = False + + assert not vllm_config.scheduler_config.chunked_prefill_enabled, ( + "Chunked prefill is not yet supported for TPU backend") + assert not vllm_config.speculative_config, ( + "Speculative decoding is not yet supported for TPU backend") + + @classmethod + def is_pin_memory_available(cls): + logger.warning("Pin memory is not supported on TPU.") + return False diff --git a/vllm/v1/attention/backends/pallas.py b/vllm/v1/attention/backends/pallas.py new file mode 100644 index 0000000000000..49d886af9dd78 --- /dev/null +++ b/vllm/v1/attention/backends/pallas.py @@ -0,0 +1,351 @@ +from dataclasses import dataclass +from typing import Any, Dict, List, Optional, Tuple, Type + +import torch +import torch_xla.experimental.custom_kernel # Required to register custom ops. + +from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, + AttentionLayer, + AttentionMetadata, AttentionType) +from vllm.attention.backends.utils import CommonAttentionState + + +class PallasAttentionBackend(AttentionBackend): + + @staticmethod + def get_name() -> str: + return "PALLAS_VLLM_V1" + + @staticmethod + def get_impl_cls() -> Type["PallasAttentionBackendImpl"]: + return PallasAttentionBackendImpl + + @staticmethod + def get_metadata_cls() -> Type["PallasMetadata"]: + return PallasMetadata + + @staticmethod + def get_state_cls() -> Type["CommonAttentionState"]: + return CommonAttentionState + + @staticmethod + def get_kv_cache_shape( + num_blocks: int, + block_size: int, + num_kv_heads: int, + head_size: int, + ) -> Tuple[int, ...]: + return (num_kv_heads, num_blocks, block_size, head_size) + + @staticmethod + def swap_blocks( + src_kv_cache: torch.Tensor, + dst_kv_cache: torch.Tensor, + src_to_dst: torch.Tensor, + ) -> None: + raise RuntimeError("swap_blocks is not used for the TPU backend.") + + @torch.compile(backend="openxla") + @staticmethod + def copy_blocks( + kv_caches: List[Tuple[torch.Tensor, torch.Tensor]], + src_to_dists: Tuple[torch.Tensor, torch.Tensor], + ) -> None: + src_indices, dst_indices = src_to_dists + for k_cache, v_cache in kv_caches: + torch.ops.xla.dynamo_set_buffer_donor_(k_cache, True) + k_cache[:, dst_indices] = k_cache[:, src_indices] + torch.ops.xla.dynamo_set_buffer_donor_(v_cache, True) + v_cache[:, dst_indices] = v_cache[:, src_indices] + + +@dataclass +class PallasMetadata(AttentionMetadata): + + # Currently, input sequences can only contain all prefills + # or all decoding. + block_tables: Optional[torch.Tensor] = None + context_lens: Optional[torch.Tensor] = None + effective_query_lens: Optional[torch.Tensor] = None + + @property + def prefill_metadata(self) -> Optional["PallasMetadata"]: + if self.num_prefills == 0: + return None + + assert self.num_decode_tokens == 0 + return self + + @property + def decode_metadata(self) -> Optional["PallasMetadata"]: + if self.num_decode_tokens == 0: + return None + + assert self.num_prefills == 0 + assert self.num_prefill_tokens == 0 + assert self.block_tables is not None + assert self.context_lens is not None + return self + + +class PallasAttentionBackendImpl(AttentionImpl): + + def __init__( + self, + num_heads: int, + head_size: int, + scale: float, + num_kv_heads: int, + alibi_slopes: Optional[List[float]], + sliding_window: Optional[int], + kv_cache_dtype: str, + blocksparse_params: Optional[Dict[str, Any]] = None, + logits_soft_cap: Optional[float] = None, + attn_type: str = AttentionType.DECODER, + ) -> None: + self.num_heads = num_heads + self.head_size = head_size + self.scale = float(scale) + self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads + + assert self.num_heads % self.num_kv_heads == 0 + self.num_queries_per_kv = self.num_heads // self.num_kv_heads + if head_size % 128 != 0: + raise NotImplementedError("Head size must be a multiple of 128.") + if alibi_slopes is not None: + raise NotImplementedError("Alibi slopes is not supported.") + if sliding_window is not None: + raise NotImplementedError("Sliding window is not supported.") + if kv_cache_dtype != "auto": + raise NotImplementedError("FP8 KV cache dtype is not supported.") + if blocksparse_params is not None: + raise NotImplementedError("Blocksparse is not supported.") + if logits_soft_cap is not None: + raise NotImplementedError( + "Attention logits soft-capping is not supported.") + + if torch_xla.tpu.version() < 4: + raise NotImplementedError("TPU version must be 4 or higher.") + + self.megacore_mode = None + tpu_env = torch_xla.tpu.get_tpu_env() + tpu_type = (tpu_env.get("ACCELERATOR_TYPE", None) + or tpu_env.get("TYPE", None) + or tpu_env.get("TPU_ACCELERATOR_TYPE", None)) + assert tpu_type is not None + tpu_type = tpu_type.lower() + + if (("lite" not in tpu_type) and ("v6" not in tpu_type)): + if self.num_kv_heads % 2 == 0: + self.megacore_mode = "kv_head" + else: + # NOTE(woosuk): If the batch size is not a multiple of 2, the + # megacore mode will be None. + self.megacore_mode = "batch" + + if attn_type != AttentionType.DECODER: + raise NotImplementedError("Encoder self-attention and " + "encoder/decoder cross-attention " + "are not implemented for " + "PallasAttentionBackendImpl") + + def forward( + self, + layer: AttentionLayer, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + kv_cache: Tuple[torch.Tensor, torch.Tensor], + attn_metadata: PallasMetadata, + output: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + """Forward pass with Pallas attention. + + Args: + query: shape = [batch_size, seq_len, num_heads * head_size] + key: shape = [batch_size, seq_len, num_kv_heads * head_size] + value: shape = [batch_size, seq_len, num_kv_heads * head_size] + kv_cache[0] = [num_kv_heads, num_blocks, block_size, head_size] + kv_cache[1] = [num_kv_heads, num_blocks, block_size, head_size] + NOTE: kv_cache[0] and kv_cache[1] will be an empty tensor + with shape [0] for profiling run. + attn_metadata: Metadata for attention. + Returns: + shape = [batch_size, seq_len, num_heads * head_size] + """ + + if attn_metadata is None: + if output is None: + output = torch.ones_like(query) + return output + + assert layer._k_scale_float == 1.0 and layer._v_scale_float == 1.0 + batch_size, seq_len, hidden_size = query.shape + query = query.view(batch_size, seq_len, self.num_heads, self.head_size) + key = key.view(batch_size, seq_len, self.num_kv_heads, self.head_size) + value = value.view(batch_size, seq_len, self.num_kv_heads, + self.head_size) + + if kv_cache[0].numel() > 0: + slot_mapping = attn_metadata.slot_mapping + key_cache, value_cache = kv_cache + write_to_kv_cache(key, value, key_cache, value_cache, slot_mapping) + + query = query * self.scale + if attn_metadata.num_prefills > 0: + if attn_metadata.block_tables is None: + # Prefill without paged KV cache. + assert seq_len % 16 == 0, ( + "Pallas FlashAttention kernel requires seq_len to be a " + f"multiple of 16 but got {seq_len}") + + # Handle GQA/MQA. + if self.num_kv_heads != self.num_heads: + key = key.repeat_interleave(self.num_queries_per_kv, + dim=-2) + key = key.view(batch_size, seq_len, self.num_heads, + self.head_size) + value = value.repeat_interleave(self.num_queries_per_kv, + dim=-2) + value = value.view(batch_size, seq_len, self.num_heads, + self.head_size) + # FlashAttention kernel requires the input shape to be + # [batch_size, num_heads, seq_len, d_model] + # while the input is [batch_size, seq_len, num_heads, d_model]. + # Permute the input to match the required format. + output = torch.ops.xla.flash_attention( + query.permute(0, 2, 1, 3), + key.permute(0, 2, 1, 3), + value.permute(0, 2, 1, 3), + True, + ) + output = output.permute(0, 2, 1, 3) + else: + # Prefill with paged KV cache. + # TODO(woosuk): Tune the below knobs. + num_kv_pages_per_compute_block = 16 + num_queries_per_compute_block = 16 + assert seq_len % num_queries_per_compute_block == 0 + output = torch.ops.xla.multi_queries_paged_attention( + query, + key_cache, + value_cache, + attn_metadata.context_lens, + attn_metadata.block_tables, + attn_metadata.effective_query_lens, + num_kv_pages_per_compute_block, + num_queries_per_compute_block, + use_kernel=True, + ) + else: + # Decoding run. + assert kv_cache[0].numel() > 0 + query = query.squeeze(dim=1) + pages_per_compute_block = 16 # TODO(woosuk): Tune this value. + + assert attn_metadata.block_tables is not None + assert attn_metadata.context_lens is not None + # NOTE(woosuk): The PagedAttention Pallas kernel stores the entire + # block table in SMEM. Therefore, if the block table is too large, + # the kernel compilation will fail. To avoid this, we split the + # batch dimension into smaller chunks and run the kernel multiple + # times. + MAX_SMEM_USAGE = 512 * 1024 + size_per_seq = 4 * attn_metadata.block_tables.shape[1] + max_num_seq = MAX_SMEM_USAGE // size_per_seq + + if batch_size <= max_num_seq: + output = paged_attention( + query, + key_cache, + value_cache, + attn_metadata.context_lens, + attn_metadata.block_tables, + pages_per_compute_block, + self.megacore_mode, + ) + else: + chunk_size = max_num_seq + # Make sure the chunk size is a multiple of 2. + chunk_size = chunk_size // 2 * 2 + num_chunks = (batch_size + chunk_size - 1) // chunk_size + + output = torch.empty_like(query) + for chunk_idx in range(num_chunks): + chunk_start = chunk_idx * chunk_size + chunk_end = chunk_start + chunk_size + # NOTE(woosuk): We skip this line because it causes Dynamo + # compilation error. Instead, we rely on the slice operation + # to handle the out-of-bound case. + # chunk_end = min(chunk_end, batch_size) + chunk_output = paged_attention( + query[chunk_start:chunk_end], + key_cache, + value_cache, + attn_metadata.context_lens[chunk_start:chunk_end], + attn_metadata.block_tables[chunk_start:chunk_end], + pages_per_compute_block, + self.megacore_mode, + ) + output[chunk_start:chunk_end] = chunk_output + + # Reshape the output tensor. + return output.reshape(batch_size, seq_len, hidden_size) + + +def write_to_kv_cache( + key: torch.Tensor, + value: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + slot_mapping: torch.Tensor, +) -> None: + torch.ops.xla.dynamo_set_buffer_donor_(key_cache, True) + torch.ops.xla.dynamo_set_buffer_donor_(value_cache, True) + + key = key.flatten(0, 2) + value = value.flatten(0, 2) + key_cache = key_cache.flatten(0, 2) + value_cache = value_cache.flatten(0, 2) + key_cache.index_copy_(0, slot_mapping, key) + value_cache.index_copy_(0, slot_mapping, value) + + +def paged_attention( + query: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + context_lens: torch.Tensor, + block_tables: torch.Tensor, + pages_per_compute_block: int, + megacore_mode: Optional[str], +) -> torch.Tensor: + batch_size = query.shape[0] + if megacore_mode == "batch" and batch_size % 2 != 0: + megacore_mode = None + else: + megacore_mode = megacore_mode + + # NOTE(woosuk): A temporary workaround to avoid the error: + # "xla::paged_attention() Expected a value of type 'str' for + # argument 'megacore_mode' but instead found type 'NoneType'." + if megacore_mode is not None: + output = torch.ops.xla.paged_attention( + query, + key_cache, + value_cache, + context_lens, + block_tables, + pages_per_compute_block, + megacore_mode=megacore_mode, + ) + else: + output = torch.ops.xla.paged_attention( + query, + key_cache, + value_cache, + context_lens, + block_tables, + pages_per_compute_block, + ) + return output diff --git a/vllm/v1/core/scheduler.py b/vllm/v1/core/scheduler.py index 8ded5e5787133..abd5285de6b16 100644 --- a/vllm/v1/core/scheduler.py +++ b/vllm/v1/core/scheduler.py @@ -212,6 +212,13 @@ class Scheduler: num_computed_tokens -= self.block_size num_new_tokens = self.block_size computed_blocks.pop() + + # If chunked prefill is not enabled, then breakout of the loop + # when above budget. + if (not self.scheduler_config.chunked_prefill_enabled + and num_new_tokens > token_budget): + break + num_new_tokens = min(num_new_tokens, token_budget) assert num_new_tokens > 0 diff --git a/vllm/v1/worker/gpu_input_batch.py b/vllm/v1/worker/gpu_input_batch.py index 28d8e39053874..07187968bee7c 100644 --- a/vllm/v1/worker/gpu_input_batch.py +++ b/vllm/v1/worker/gpu_input_batch.py @@ -72,7 +72,7 @@ class InputBatch: self.token_ids_cpu = self.token_ids_cpu_tensor.numpy() self.num_tokens = np.zeros(max_num_reqs, dtype=np.int32) self.num_prompt_tokens = np.zeros(max_num_reqs, dtype=np.int32) - self.num_computed_tokens_cpu = np.empty(max_num_reqs, dtype=np.int32) + self.num_computed_tokens_cpu = np.zeros(max_num_reqs, dtype=np.int32) # Block table. self.block_table = BlockTable( diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 4b3c325ded906..1f5b2f67079ca 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -5,10 +5,7 @@ from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, cast import numpy as np import torch import torch.distributed -import torch.nn as nn -from vllm.attention.backends.abstract import AttentionType -from vllm.attention.layer import Attention from vllm.config import CompilationLevel, VllmConfig from vllm.distributed.parallel_state import graph_capture from vllm.forward_context import set_forward_context @@ -19,18 +16,17 @@ from vllm.model_executor.model_loader import get_model from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs from vllm.multimodal.utils import group_mm_inputs_by_modality from vllm.sampling_params import SamplingType -from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler, - LayerBlockType, cdiv, is_pin_memory_available) +from vllm.utils import DeviceMemoryProfiler, cdiv from vllm.v1.attention.backends.flash_attn import (FlashAttentionBackend, FlashAttentionMetadata) from vllm.v1.core.encoder_cache_manager import compute_encoder_budget from vllm.v1.engine.mm_input_mapper import MMInputMapperClient -from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig, - KVCacheSpec) +from vllm.v1.kv_cache_interface import FullAttentionSpec, KVCacheConfig from vllm.v1.outputs import ModelRunnerOutput from vllm.v1.sample.metadata import SamplingMetadata from vllm.v1.utils import bind_kv_cache from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch +from vllm.v1.worker.model_runner_base import ExecutionMode, ModelRunnerBase if TYPE_CHECKING: from vllm.v1.core.scheduler import SchedulerOutput @@ -38,53 +34,30 @@ if TYPE_CHECKING: logger = init_logger(__name__) -class GPUModelRunner: +class GPUModelRunner(ModelRunnerBase): def __init__( self, vllm_config: VllmConfig, device: torch.device, ): - self.vllm_config = vllm_config - self.model_config = vllm_config.model_config - self.cache_config = vllm_config.cache_config - self.lora_config = vllm_config.lora_config - self.load_config = vllm_config.load_config - self.parallel_config = vllm_config.parallel_config - self.scheduler_config = vllm_config.scheduler_config - self.speculative_config = vllm_config.speculative_config - self.prompt_adapter_config = vllm_config.prompt_adapter_config - self.observability_config = vllm_config.observability_config + super().__init__(vllm_config, device) - model_config = self.model_config - cache_config = self.cache_config - scheduler_config = self.scheduler_config - parallel_config = self.parallel_config - self.device = device - self.pin_memory = is_pin_memory_available() - self.dtype = self.model_config.dtype - if cache_config.cache_dtype == "auto": - self.kv_cache_dtype = self.dtype - else: - self.kv_cache_dtype = STR_DTYPE_TO_TORCH_DTYPE[ - cache_config.cache_dtype] + # Persistent batch. + self.input_batch = InputBatch( + max_num_reqs=self.max_num_reqs, + max_model_len=self.max_model_len, + max_num_blocks_per_req=self.max_num_blocks_per_req, + device=self.device, + pin_memory=self.pin_memory, + vocab_size=self.model_config.get_vocab_size(), + ) - self.is_multimodal_model = model_config.is_multimodal_model - self.sliding_window = model_config.get_sliding_window() - self.block_size = cache_config.block_size - self.max_model_len = model_config.max_model_len - self.max_num_blocks_per_req = cdiv(self.max_model_len, self.block_size) - self.max_num_tokens = scheduler_config.max_num_batched_tokens - self.max_num_reqs = scheduler_config.max_num_seqs + # Request states. + self.requests: Dict[str, CachedRequestState] = {} - # Model-related. - self.num_attn_layers = model_config.get_num_layers_by_block_type( - parallel_config, LayerBlockType.attention) - self.num_query_heads = model_config.get_num_attention_heads( - parallel_config) - self.num_kv_heads = model_config.get_num_kv_heads(parallel_config) - self.head_size = model_config.get_head_size() - self.hidden_size = model_config.get_hidden_size() + # KV caches for forward pass + self.kv_caches: List[torch.Tensor] = [] # Multi-modal data support self.input_registry = INPUT_REGISTRY @@ -96,30 +69,15 @@ class GPUModelRunner: self.mm_input_mapper_profiling.use_cache = False encoder_compute_budget, encoder_cache_size = compute_encoder_budget( - model_config=model_config, - scheduler_config=scheduler_config, + model_config=self.model_config, + scheduler_config=self.scheduler_config, ) self.max_num_encoder_input_tokens = encoder_compute_budget self.encoder_cache_size = encoder_cache_size - # Lazy initialization - # self.model: nn.Module # Set after load_model - self.kv_caches: List[torch.Tensor] = [] # req_id -> (input_id -> encoder_output) self.encoder_cache: Dict[str, Dict[int, torch.Tensor]] = {} - # Request states. - self.requests: Dict[str, CachedRequestState] = {} - # Persistent batch. - self.input_batch = InputBatch( - max_num_reqs=self.max_num_reqs, - max_model_len=self.max_model_len, - max_num_blocks_per_req=self.max_num_blocks_per_req, - device=self.device, - pin_memory=self.pin_memory, - vocab_size=model_config.get_vocab_size(), - ) - self.use_cuda_graph = (self.vllm_config.compilation_config.level == CompilationLevel.PIECEWISE and not self.model_config.enforce_eager) @@ -611,6 +569,8 @@ class GPUModelRunner: return sampling_metadata def _execute_encoder(self, scheduler_output: "SchedulerOutput"): + assert self.model is not None + scheduled_encoder_inputs = scheduler_output.scheduled_encoder_inputs if not scheduled_encoder_inputs: return @@ -698,14 +658,13 @@ class GPUModelRunner: encoder_outputs.append(encoder_output[start_idx:end_idx]) return encoder_outputs - def get_model(self) -> nn.Module: - return self.model - @torch.inference_mode() def execute_model( self, scheduler_output: "SchedulerOutput", ) -> ModelRunnerOutput: + assert self.model is not None + self._update_states(scheduler_output) if self.is_multimodal_model: @@ -833,14 +792,15 @@ class GPUModelRunner: self.model_memory_usage / float(2**30)) @torch.inference_mode() - def _dummy_run( + def dummy_run( self, + kv_caches, num_tokens: int, - kv_caches: Optional[List[torch.Tensor]] = None, + seq_len: Optional[int] = None, + exec_mode: Optional[ExecutionMode] = None, ) -> torch.Tensor: - model = self.model - if kv_caches is None: - kv_caches = self.kv_caches + assert self.model is not None + if self.is_multimodal_model: input_ids = None inputs_embeds = self.inputs_embeds[:num_tokens] @@ -851,7 +811,7 @@ class GPUModelRunner: positions = self.mrope_positions[:, :num_tokens] \ if self.model_config.uses_mrope \ else self.positions[:num_tokens] - hidden_states = model( + hidden_states = self.model( input_ids=input_ids, positions=positions, kv_caches=kv_caches, @@ -861,6 +821,7 @@ class GPUModelRunner: return hidden_states def profile_run(self) -> None: + assert self.model is not None # use an empty tensor instead of `None`` to force Dynamo to pass # it by reference, rather by specializing on the value `None`. # the `dtype` argument does not matter, and we use `float32` as @@ -966,7 +927,7 @@ class GPUModelRunner: self.encoder_cache["tmp"] = dict(enumerate(dummy_encoder_outputs)) # Trigger compilation for general shape. - hidden_states = self._dummy_run(self.max_num_tokens, dummy_kv_caches) + hidden_states = self.dummy_run(dummy_kv_caches, self.max_num_tokens) logits = self.model.compute_logits(hidden_states, None) logits = logits[:self.max_num_tokens] # TODO(woosuk): Consider the memory usage of the sampler. @@ -992,8 +953,8 @@ class GPUModelRunner: for num_tokens in reversed(self.cudagraph_batch_sizes): for _ in range(self.vllm_config.compilation_config. cudagraph_num_of_warmups): - self._dummy_run(num_tokens) - self._dummy_run(num_tokens) + self.dummy_run(None, num_tokens) + self.dummy_run(None, num_tokens) end_time = time.perf_counter() end_free_gpu_memory = torch.cuda.mem_get_info()[0] @@ -1036,38 +997,3 @@ class GPUModelRunner: kv_caches, self.vllm_config.compilation_config.static_forward_context, self.kv_caches) - - def get_kv_cache_spec(self) -> KVCacheSpec: - """ - Generates the KVCacheSpec by parsing the kv cache format from each - Attention module in the static forward context. - Returns: - KVCacheSpec: A dictionary mapping layer names to their KV cache - format. Layers that do not need KV cache are not included. - """ - - forward_ctx = self.vllm_config.compilation_config.static_forward_context - block_size = self.vllm_config.cache_config.block_size - kv_cache_spec: KVCacheSpec = {} - for layer_name, attn_module in forward_ctx.items(): - # TODO: Support other attention modules, e.g., sliding window, - # cross-attention, MLA. - assert isinstance(attn_module, Attention) - if attn_module.attn_type == AttentionType.DECODER: - kv_cache_spec[layer_name] = FullAttentionSpec( - block_size=block_size, - num_kv_heads=attn_module.num_kv_heads, - head_size=attn_module.head_size, - dtype=attn_module.dtype, - ) - elif attn_module.attn_type in (AttentionType.ENCODER, - AttentionType.ENCODER_ONLY): - # encoder-only attention does not need KV cache. - continue - elif attn_module.attn_type == AttentionType.ENCODER_DECODER: - raise NotImplementedError - else: - raise ValueError( - f"Unknown attention type: {attn_module.attn_type}") - - return kv_cache_spec diff --git a/vllm/v1/worker/gpu_worker.py b/vllm/v1/worker/gpu_worker.py index a8cf0aec3f17b..1aaca42713c58 100644 --- a/vllm/v1/worker/gpu_worker.py +++ b/vllm/v1/worker/gpu_worker.py @@ -1,13 +1,11 @@ """A GPU worker class.""" import gc import os -from typing import TYPE_CHECKING, Optional +from typing import Optional import torch import torch.distributed -import torch.nn as nn -import vllm.envs as envs from vllm.config import ParallelConfig, VllmConfig from vllm.device_allocator.cumem import CuMemAllocator from vllm.distributed import (ensure_model_parallel_initialized, @@ -15,20 +13,17 @@ from vllm.distributed import (ensure_model_parallel_initialized, set_custom_all_reduce) from vllm.logger import init_logger from vllm.model_executor import set_random_seed -from vllm.platforms import current_platform from vllm.utils import GiB_bytes from vllm.v1.core.scheduler import SchedulerOutput -from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec +from vllm.v1.kv_cache_interface import KVCacheConfig from vllm.v1.outputs import ModelRunnerOutput from vllm.v1.worker.gpu_model_runner import GPUModelRunner +from vllm.v1.worker.worker_base import WorkerBase, check_if_gpu_supports_dtype logger = init_logger(__name__) -if TYPE_CHECKING: - from vllm.v1.core.scheduler import SchedulerOutput - -class Worker: +class GPUWorker(WorkerBase): def __init__( self, @@ -38,46 +33,8 @@ class Worker: distributed_init_method: str, is_driver_worker: bool = False, ): - - # TODO: use WorkerBase.__init__(self, vllm_config=vllm_config) - self.vllm_config = vllm_config - self.model_config = vllm_config.model_config - self.cache_config = vllm_config.cache_config - self.lora_config = vllm_config.lora_config - self.load_config = vllm_config.load_config - self.parallel_config = vllm_config.parallel_config - self.scheduler_config = vllm_config.scheduler_config - self.device_config = vllm_config.device_config - self.speculative_config = vllm_config.speculative_config - self.prompt_adapter_config = vllm_config.prompt_adapter_config - self.observability_config = vllm_config.observability_config - - self.parallel_config.rank = rank - self.local_rank = local_rank - self.rank = rank - self.distributed_init_method = distributed_init_method - - if self.model_config.trust_remote_code: - # note: lazy import to avoid importing torch before initializing - from vllm.utils import init_cached_hf_modules - init_cached_hf_modules() - - # Torch profiler. Enabled and configured through env vars: - # VLLM_TORCH_PROFILER_DIR=/path/to/save/trace - if envs.VLLM_TORCH_PROFILER_DIR: - torch_profiler_trace_dir = envs.VLLM_TORCH_PROFILER_DIR - logger.info("Profiling enabled. Traces will be saved to: %s", - torch_profiler_trace_dir) - self.profiler = torch.profiler.profile( - activities=[ - torch.profiler.ProfilerActivity.CPU, - torch.profiler.ProfilerActivity.CUDA, - ], - with_stack=True, - on_trace_ready=torch.profiler.tensorboard_trace_handler( - torch_profiler_trace_dir, use_gzip=True)) - else: - self.profiler = None + super().__init__(vllm_config, local_rank, rank, + distributed_init_method) def sleep(self, level: int = 1) -> None: free_bytes_before_sleep = torch.cuda.mem_get_info()[0] @@ -97,31 +54,39 @@ class Worker: allocator.wake_up() def init_device(self): - if self.device_config.device.type == "cuda": - # torch.distributed.all_reduce does not free the input tensor until - # the synchronization point. This causes the memory usage to grow - # as the number of all_reduce calls increases. This env var disables - # this behavior. - # Related issue: - # https://discuss.pytorch.org/t/cuda-allocation-lifetime-for-inputs-to-distributed-all-reduce/191573 - os.environ["TORCH_NCCL_AVOID_RECORD_STREAMS"] = "1" + assert self.device_config.device.type == "cuda" - # This env var set by Ray causes exceptions with graph building. - os.environ.pop("NCCL_ASYNC_ERROR_HANDLING", None) - self.device = torch.device(f"cuda:{self.local_rank}") - torch.cuda.set_device(self.device) + # torch.distributed.all_reduce does not free the input tensor until + # the synchronization point. This causes the memory usage to grow + # as the number of all_reduce calls increases. This env var disables + # this behavior. + # Related issue: + # https://discuss.pytorch.org/t/cuda-allocation-lifetime-for-inputs-to-distributed-all-reduce/191573 + os.environ["TORCH_NCCL_AVOID_RECORD_STREAMS"] = "1" + + # torch.distributed.all_reduce does not free the input tensor until + # the synchronization point. This causes the memory usage to grow + # as the number of all_reduce calls increases. This env var disables + # this behavior. + # Related issue: + # https://discuss.pytorch.org/t/cuda-allocation-lifetime-for-inputs-to-distributed-all-reduce/191573 + os.environ["TORCH_NCCL_AVOID_RECORD_STREAMS"] = "1" + + # This env var set by Ray causes exceptions with graph building. + os.environ.pop("NCCL_ASYNC_ERROR_HANDLING", None) + self.device = torch.device(f"cuda:{self.local_rank}") + torch.cuda.set_device(self.device) + + check_if_gpu_supports_dtype(self.model_config.dtype) + gc.collect() + torch.cuda.empty_cache() + self.init_gpu_memory = torch.cuda.mem_get_info()[0] - _check_if_gpu_supports_dtype(self.model_config.dtype) - gc.collect() - torch.cuda.empty_cache() - self.init_gpu_memory = torch.cuda.mem_get_info()[0] - else: - raise RuntimeError( - f"Not support device type: {self.device_config.device}") # Initialize the distributed environment. - init_worker_distributed_environment(self.parallel_config, self.rank, - self.distributed_init_method, - self.local_rank) + init_cuda_worker_distributed_environment(self.parallel_config, + self.rank, + self.distributed_init_method, + self.local_rank) # Set random seed. set_random_seed(self.model_config.seed) @@ -139,6 +104,7 @@ class Worker: from contextlib import nullcontext context = nullcontext() with context: + assert self.model_runner is not None self.model_runner.load_model() @torch.inference_mode() @@ -160,6 +126,7 @@ class Worker: _, total_gpu_memory = torch.cuda.mem_get_info() # Execute a forward pass with dummy inputs to profile the memory usage # of the model. + assert self.model_runner is not None self.model_runner.profile_run() free_gpu_memory, _ = torch.cuda.mem_get_info() @@ -191,9 +158,6 @@ class Worker: return int(available_kv_cache_memory) - def get_kv_cache_spec(self) -> KVCacheSpec: - return self.model_runner.get_kv_cache_spec() - def initialize_cache(self, kv_cache_config: KVCacheConfig) -> None: """Allocate GPU KV cache with the specified kv_cache_config.""" if self.vllm_config.model_config.enable_sleep_mode: @@ -203,9 +167,12 @@ class Worker: from contextlib import nullcontext context = nullcontext() with context: + assert self.model_runner is not None self.model_runner.initialize_kv_cache(kv_cache_config) def compile_or_warm_up_model(self) -> None: + assert self.model_runner is not None + # warm up sizes that are not in cudagraph capture sizes, # but users still want to compile for better performance, # e.g. for the max-num-batched token size in chunked prefill. @@ -217,44 +184,32 @@ class Worker: ] for size in sorted(warmup_sizes, reverse=True): logger.info("Compile and warming up model for size %d", size) - self.model_runner._dummy_run(size) + self.model_runner.dummy_run(None, size) + if not self.model_config.enforce_eager: self.model_runner.capture_model() # Reset the seed to ensure that the random state is not affected by # the model initialization and profiling. set_random_seed(self.model_config.seed) - def get_model(self) -> nn.Module: - return self.model_runner.get_model() - @torch.inference_mode() def execute_model( self, scheduler_output: "SchedulerOutput", ) -> Optional[ModelRunnerOutput]: + assert self.model_runner is not None output = self.model_runner.execute_model(scheduler_output) return output if self.rank == 0 else None - def profile(self, is_start: bool = True): - if self.profiler is None: - raise RuntimeError("Profiler is not enabled.") - if is_start: - self.profiler.start() - else: - self.profiler.stop() - def check_health(self) -> None: - # worker will always be healthy as long as it's running. - return - - -def init_worker_distributed_environment( +def init_cuda_worker_distributed_environment( parallel_config: ParallelConfig, rank: int, distributed_init_method: Optional[str] = None, local_rank: int = -1, ) -> None: """Initialize the distributed environment.""" + set_custom_all_reduce(not parallel_config.disable_custom_all_reduce) init_distributed_environment(parallel_config.world_size, rank, @@ -264,21 +219,22 @@ def init_worker_distributed_environment( parallel_config.pipeline_parallel_size) -def _check_if_gpu_supports_dtype(torch_dtype: torch.dtype): - # Check if the GPU supports the dtype. - if torch_dtype == torch.bfloat16: # noqa: SIM102 - if not current_platform.has_device_capability(80): - capability = current_platform.get_device_capability() - gpu_name = current_platform.get_device_name() +# TODO: Remove +# def _check_if_gpu_supports_dtype(torch_dtype: torch.dtype): +# # Check if the GPU supports the dtype. +# if torch_dtype == torch.bfloat16: # noqa: SIM102 +# if not current_platform.has_device_capability(80): +# capability = current_platform.get_device_capability() +# gpu_name = current_platform.get_device_name() - if capability is None: - compute_str = "does not have a compute capability" - else: - version_str = capability.as_version_str() - compute_str = f"has compute capability {version_str}" +# if capability is None: +# compute_str = "does not have a compute capability" +# else: +# version_str = capability.as_version_str() +# compute_str = f"has compute capability {version_str}" - raise ValueError( - "Bfloat16 is only supported on GPUs with compute capability " - f"of at least 8.0. Your {gpu_name} GPU {compute_str}. " - "You can use float16 instead by explicitly setting the" - "`dtype` flag in CLI, for example: --dtype=half.") +# raise ValueError( +# "Bfloat16 is only supported on GPUs with compute capability " +# f"of at least 8.0. Your {gpu_name} GPU {compute_str}. " +# "You can use float16 instead by explicitly setting the" +# "`dtype` flag in CLI, for example: --dtype=half.") diff --git a/vllm/v1/worker/model_runner_base.py b/vllm/v1/worker/model_runner_base.py new file mode 100644 index 0000000000000..e46242cb13841 --- /dev/null +++ b/vllm/v1/worker/model_runner_base.py @@ -0,0 +1,142 @@ +import enum +from typing import TYPE_CHECKING, Optional + +import torch +import torch.distributed +import torch.nn as nn + +from vllm.attention.backends.abstract import AttentionType +from vllm.attention.layer import Attention +from vllm.config import VllmConfig +from vllm.logger import init_logger +from vllm.utils import LayerBlockType, cdiv, is_pin_memory_available +from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig, + KVCacheSpec) +from vllm.v1.outputs import ModelRunnerOutput + +if TYPE_CHECKING: + from vllm.v1.core.scheduler import SchedulerOutput + +logger = init_logger(__name__) + + +class ExecutionMode(enum.Enum): + PREFILL = enum.auto() + DECODE = enum.auto() + PREFIX_PREFILL = enum.auto() + + def is_prefill(self) -> bool: + return self in (ExecutionMode.PREFILL, ExecutionMode.PREFIX_PREFILL) + + +class ModelRunnerBase: + + def __init__( + self, + vllm_config: VllmConfig, + device: torch.device, + ): + self.vllm_config = vllm_config + self.model_config = vllm_config.model_config + self.cache_config = vllm_config.cache_config + self.lora_config = vllm_config.lora_config + self.load_config = vllm_config.load_config + self.parallel_config = vllm_config.parallel_config + self.scheduler_config = vllm_config.scheduler_config + self.speculative_config = vllm_config.speculative_config + self.prompt_adapter_config = vllm_config.prompt_adapter_config + self.observability_config = vllm_config.observability_config + self.device_config = vllm_config.device_config + + model_config = self.model_config + cache_config = self.cache_config + scheduler_config = self.scheduler_config + parallel_config = self.parallel_config + self.device = device + self.pin_memory = is_pin_memory_available() + self.dtype = self.model_config.dtype + + self.is_multimodal_model = model_config.is_multimodal_model + self.sliding_window = model_config.get_sliding_window() + self.block_size = cache_config.block_size + self.max_model_len = model_config.max_model_len + self.max_num_blocks_per_req = cdiv(self.max_model_len, self.block_size) + self.max_num_tokens = scheduler_config.max_num_batched_tokens + self.max_num_reqs = scheduler_config.max_num_seqs + + # Model-related. + self.num_attn_layers = model_config.get_num_layers_by_block_type( + parallel_config, LayerBlockType.attention) + self.num_query_heads = model_config.get_num_attention_heads( + parallel_config) + self.num_kv_heads = model_config.get_num_kv_heads(parallel_config) + self.head_size = model_config.get_head_size() + self.hidden_size = model_config.get_hidden_size() + + self.model: Optional[nn.Module] = None + + def get_model(self) -> nn.Module: + assert self.model is not None + return self.model + + def execute_model( + self, + scheduler_output: "SchedulerOutput", + ) -> ModelRunnerOutput: + raise NotImplementedError() + + def load_model(self) -> None: + raise NotImplementedError() + + def dummy_run( + self, + kv_caches, + num_tokens: int, + seq_len: Optional[int] = None, + exec_mode: Optional[ExecutionMode] = None, + ) -> torch.Tensor: + raise NotImplementedError() + + def profile_run(self) -> None: + raise NotImplementedError() + + def capture_model(self) -> None: + raise NotImplementedError() + + def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None: + raise NotImplementedError() + + def get_kv_cache_spec(self) -> KVCacheSpec: + """ + Generates the KVCacheSpec by parsing the kv cache format from each + Attention module in the static forward context. + Returns: + KVCacheSpec: A dictionary mapping layer names to their KV cache + format. Layers that do not need KV cache are not included. + """ + + forward_ctx = self.vllm_config.compilation_config.static_forward_context + block_size = self.vllm_config.cache_config.block_size + kv_cache_spec: KVCacheSpec = {} + for layer_name, attn_module in forward_ctx.items(): + # TODO: Support other attention modules, e.g., sliding window, + # cross-attention, MLA. + assert isinstance(attn_module, Attention) + if attn_module.attn_type == AttentionType.DECODER: + kv_cache_spec[layer_name] = FullAttentionSpec( + block_size=block_size, + num_kv_heads=attn_module.num_kv_heads, + head_size=attn_module.head_size, + dtype=attn_module.dtype, + ) + elif attn_module.attn_type in (AttentionType.ENCODER, + AttentionType.ENCODER_ONLY): + # encoder-only attention does not need KV cache. + continue + elif attn_module.attn_type == AttentionType.ENCODER_DECODER: + raise NotImplementedError + else: + raise ValueError( + f"Unknown attention type: {attn_module.attn_type}") + + return kv_cache_spec diff --git a/vllm/v1/worker/tpu_model_runner.py b/vllm/v1/worker/tpu_model_runner.py new file mode 100644 index 0000000000000..7be9bd27aeb4f --- /dev/null +++ b/vllm/v1/worker/tpu_model_runner.py @@ -0,0 +1,778 @@ +import time +from dataclasses import dataclass +from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, cast +from unittest.mock import patch + +import torch +import torch.distributed +import torch.nn as nn +# TPU XLA related +import torch_xla.core.xla_model as xm +import torch_xla.runtime as xr + +from vllm.attention import AttentionMetadata +from vllm.config import VllmConfig +from vllm.forward_context import set_forward_context +from vllm.logger import init_logger +from vllm.model_executor.model_loader import get_model +from vllm.sampling_params import SamplingType +from vllm.v1.attention.backends.pallas import (PallasAttentionBackend, + PallasMetadata) +from vllm.v1.kv_cache_interface import FullAttentionSpec, KVCacheConfig +from vllm.v1.outputs import ModelRunnerOutput +from vllm.v1.utils import bind_kv_cache +from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch +from vllm.v1.worker.model_runner_base import ExecutionMode, ModelRunnerBase + +if TYPE_CHECKING: + from vllm.v1.core.scheduler import SchedulerOutput + +logger = init_logger(__name__) + +# Here we utilize the behavior that out-of-bound index is ignored. +# FIXME(woosuk): Find a more reliable way to prevent possible bugs. +_PAD_SLOT_ID = 1_000_000_000 +# FIXME(woosuk): Temporarily disabled top-p sampling since it's too slow. +_ENABLE_TOP_P = False +# FIXME(woosuk): A temporary hack to support `n > 1`. +# This can significantly affect the performance if too large. +_MAX_NUM_SAMPLES = 128 + + +@dataclass +class PrefillInputData: + + request_ids: List + prompt_lens: List + token_ids: List + position_ids: List + attn_metadata: List + + def zipped(self): + return zip(self.request_ids, self.prompt_lens, self.token_ids, + self.position_ids, self.attn_metadata) + + +@dataclass +class DecodeInputData: + + num_decodes: int + token_ids: Optional[torch.Tensor] = None + position_ids: Optional[torch.Tensor] = None + attn_metadata: Optional[PallasMetadata] = None + + +class TPUModelRunner(ModelRunnerBase): + + def __init__( + self, + vllm_config: VllmConfig, + device: torch.device, + ): + super().__init__(vllm_config, device) + + # Persistent batch. + self.input_batch = InputBatch( + max_num_reqs=self.max_num_reqs, + max_model_len=self.max_model_len, + max_num_blocks_per_req=self.max_num_blocks_per_req, + device=self.device, + pin_memory=self.pin_memory, + vocab_size=self.model_config.get_vocab_size(), + ) + + # Request states. + self.requests: Dict[str, CachedRequestState] = {} + + # KV caches for forward pass + self.kv_caches: List[Tuple[torch.Tensor, torch.Tensor]] = [] + + # Used to initialize positions for the individual prefills + self.prefill_positions = torch.tensor(range(self.max_model_len), + device="cpu", + dtype=torch.int32).reshape( + 1, -1) + + # Used to indicate how many prefills there are for each scheduler + # iteration + self.num_new_reqs: int = 0 + + def _update_states(self, scheduler_output: "SchedulerOutput") -> None: + # Remove stopped requests from the cached states. + # Keep the states of the pre-empted requests. + for req_id in scheduler_output.finished_req_ids: + self.requests.pop(req_id, None) + + # Remove the requests from the persistent batch. + stopped_req_ids = set().union( + scheduler_output.preempted_req_ids, + scheduler_output.finished_req_ids, + ) + removed_req_indices: List[int] = [] + for req_id in stopped_req_ids: + req_index = self.input_batch.remove_request(req_id) + if req_index is not None: + removed_req_indices.append(req_index) + + # Update the states of the running requests. + for req_data in scheduler_output.scheduled_running_reqs: + req_id = req_data.req_id + req_state = self.requests[req_id] + req_index = self.input_batch.req_id_to_index[req_id] + + # Update the num_computed_tokens. + req_state.num_computed_tokens = req_data.num_computed_tokens + self.input_batch.num_computed_tokens_cpu[req_index] = ( + req_data.num_computed_tokens) + + # Update the block table. + num_new_blocks = len(req_data.new_block_ids) + if num_new_blocks == 0: + continue + start_index = len(req_state.block_ids) + req_state.block_ids.extend(req_data.new_block_ids) + self.input_batch.block_table.append_row(req_index, start_index, + req_data.new_block_ids) + + req_ids_to_add: List[str] = [] + # Add new requests to the cached states. + for new_req_data in scheduler_output.scheduled_new_reqs: + req_id = new_req_data.req_id + sampling_params = new_req_data.sampling_params + if sampling_params.sampling_type == SamplingType.RANDOM_SEED: + generator = torch.Generator(device=self.device) + generator.manual_seed(sampling_params.seed) + else: + generator = None + + self.requests[req_id] = CachedRequestState( + req_id=req_id, + prompt_token_ids=new_req_data.prompt_token_ids, + prompt=new_req_data.prompt, + mm_inputs=new_req_data.mm_inputs, + mm_positions=new_req_data.mm_positions, + sampling_params=sampling_params, + generator=generator, + block_ids=new_req_data.block_ids, + num_computed_tokens=new_req_data.num_computed_tokens, + output_token_ids=[], + ) + req_ids_to_add.append(req_id) + + # Update the cached states of the resumed requests. + for res_req_data in scheduler_output.scheduled_resumed_reqs: + req_id = res_req_data.req_id + req_state = self.requests[req_id] + + req_state.block_ids = res_req_data.block_ids + req_state.num_computed_tokens = res_req_data.num_computed_tokens + req_ids_to_add.append(req_id) + + # For TPU, we keep all of the decode requests before the + # prefill requests in the batch sequence. + # 1. First condense, so all decodes move to start + # 2. Then add new prefills to the end of the batch + removed_req_indices = sorted(removed_req_indices, reverse=True) + if removed_req_indices: + self.input_batch.condense(removed_req_indices) + + for req_id in req_ids_to_add: + req_state = self.requests[req_id] + self.input_batch.add_request(req_state, None) # Append last + + self.num_new_reqs = len(req_ids_to_add) + + def _prepare_prefill_inputs( + self, + num_scheduled_tokens: List[int], + ) -> PrefillInputData: + # Each prefill run separately with shape [1, padded_prompt_len]. + # So we create lists that will be used in execute_model(). + + prefill_request_ids = [] + prefill_prompt_lens = [] + prefill_token_ids = [] + prefill_position_ids = [] + prefill_attn_metadata = [] + + # DECODES are the first num_decodes REQUESTS. + # PREFILLS are the next num_reqs - num_decodes REQUESTS. + num_reqs = self.input_batch.num_reqs + num_decodes = num_reqs - self.num_new_reqs + for idx in range(num_decodes, num_reqs): + req_id = self.input_batch.req_ids[idx] + prefill_request_ids.append(req_id) + + prompt_len = num_scheduled_tokens[idx] + prefill_prompt_lens.append(prompt_len) + + # STATIC SHAPE: prefills are padded to the next power of 2. + padded_prompt_len = _get_padded_prefill_len(prompt_len) + assert padded_prompt_len <= self.max_model_len + + # TOKEN_IDS. + token_ids = torch.from_numpy(self.input_batch.token_ids_cpu[ + idx, :padded_prompt_len].reshape(1, -1)) + token_ids[:, prompt_len:] = 0 + prefill_token_ids.append(token_ids.to(self.device)) + + # POSITIONS. + positions = self.prefill_positions[:, :padded_prompt_len].clone() + positions[:, prompt_len:] = 0 + prefill_position_ids.append(positions.to(self.device)) + + # SLOT_MAPPING. + # The "slot" is the "physical index" of a token in the KV cache. + # Look up the block_idx in the block table (logical<>physical map) + # to compute this. + block_table_cpu_tensor = \ + self.input_batch.block_table.get_cpu_tensor() + + block_numbers = block_table_cpu_tensor[idx, positions // + self.block_size].reshape( + 1, -1) + + block_offsets = positions % self.block_size + slot_mapping = block_numbers * self.block_size + block_offsets + # Set an out of range value for the padding tokens so that they + # are ignored when inserting into the KV cache. + slot_mapping[:, prompt_len:] = _PAD_SLOT_ID + slot_mapping = slot_mapping.long() + + prefill_attn_metadata.append( + PallasMetadata( + num_prefills=1, + num_prefill_tokens=0, # NOTE: This is not used. + num_decode_tokens=0, + slot_mapping=slot_mapping.to(self.device), + multi_modal_placeholder_index_maps=None, + enable_kv_scales_calculation=True, + block_tables=None, + context_lens=None, + effective_query_lens=None, + )) + + return PrefillInputData( + request_ids=prefill_request_ids, + prompt_lens=prefill_prompt_lens, + token_ids=prefill_token_ids, + position_ids=prefill_position_ids, + attn_metadata=prefill_attn_metadata, + ) + + def _prepare_decode_inputs(self) -> DecodeInputData: + # Decodes run as one single padded batch with shape [batch, 1] + # + # We need to set _PAD_SLOT_ID for the padding tokens in the + # slot_mapping, such that the attention KV cache insertion + # logic knows to ignore those indices. Otherwise, the + # padding data can be dummy since we have a causal mask. + + # DECODES are the first num_decodes REQUESTS. + # PREFILLS are the next num_reqs - num_decodes REQUESTS. + num_reqs = self.input_batch.num_reqs + num_decodes = num_reqs - self.num_new_reqs + + if num_decodes == 0: + return DecodeInputData(num_decodes=0) + + # PAD FOR STATIC SHAPES. + padded_batch_size = _get_padded_batch_size(num_decodes) + + # POSITIONS. [batch, 1] + # We slice at the end, since we use the positions for gathering. + positions = torch.from_numpy( + self.input_batch.num_computed_tokens_cpu.reshape(-1, 1)) + index = positions.to(torch.int64) + index[num_decodes:] = 0 + positions = positions[:padded_batch_size] + positions[num_decodes:] = 0 + + # TOKEN_IDS. [batch, 1] + token_ids = torch.gather( + input=torch.from_numpy(self.input_batch.token_ids_cpu), + dim=1, + index=index, + )[:padded_batch_size].to(torch.int32) + token_ids[num_decodes:] = 0 + + # SLOT_MAPPING [batch, 1] + # The "slot" is the "physical index" of a token in the KV cache. + # Look up the block_idx in the block table (logical<>physical map) + # to compute this. + block_table_cpu_tensor = self.input_batch.block_table.get_cpu_tensor() + block_number = torch.gather(input=block_table_cpu_tensor, + dim=1, + index=(index // self.block_size)) + block_offsets = index % self.block_size + slot_mapping = block_number * self.block_size + block_offsets + # Set an out of range value for the padding tokens so that they + # are ignored when inserting into the KV cache. + slot_mapping[num_decodes:] = _PAD_SLOT_ID + slot_mapping = slot_mapping[:padded_batch_size] + slot_mapping = slot_mapping.long() + + # BLOCK_TABLE [batch, max_num_blocks_per_req] + block_table = block_table_cpu_tensor[:padded_batch_size] + + # CONTEXT_LENS [batch_size] + context_lens = (positions.reshape(-1) + 1) + context_lens[num_decodes:] = 0 + + # CPU<>TPU sync happens here. + return DecodeInputData(num_decodes=num_decodes, + token_ids=token_ids.to(self.device), + position_ids=positions.to(self.device), + attn_metadata=PallasMetadata( + num_prefills=0, + num_prefill_tokens=0, + num_decode_tokens=padded_batch_size, + slot_mapping=slot_mapping.to(self.device), + multi_modal_placeholder_index_maps=None, + enable_kv_scales_calculation=True, + block_tables=block_table.to(self.device), + context_lens=context_lens.to(self.device), + effective_query_lens=None, + )) + + def _prepare_inputs(self, scheduler_output: "SchedulerOutput"): + total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens + assert total_num_scheduled_tokens > 0 + num_reqs = self.input_batch.num_reqs + assert num_reqs > 0 + + num_decodes = num_reqs - self.num_new_reqs + + # TODO: Resurrect + # OPTIMIZATION: Start copying the block table first. + # This way, we can overlap the copy with the following CPU operations. + # TODO: Verify this works with TPUs + # self.input_batch.block_table.commit(num_reqs) + + # Get the number of scheduled tokens for each request. + # TODO: The Python loop can be slow. Optimize. + num_scheduled_tokens = [] + max_num_scheduled_tokens = 0 + for idx, req_id in enumerate(self.input_batch.req_ids[:num_reqs]): + assert req_id is not None + num_tokens = scheduler_output.num_scheduled_tokens[req_id] + num_scheduled_tokens.append(num_tokens) + max_num_scheduled_tokens = max(max_num_scheduled_tokens, + num_tokens) + + # NOTE: Assert that all the decodes are "decodes". + if idx < num_decodes: + assert num_tokens == 1 + + assert max_num_scheduled_tokens > 0 + + return ( + self._prepare_prefill_inputs(num_scheduled_tokens), + self._prepare_decode_inputs(), + ) + + @torch.no_grad() + def execute_model( + self, + scheduler_output: "SchedulerOutput", + ) -> ModelRunnerOutput: + self._update_states(scheduler_output) + + # Prepare the decoder inputs. + prefill_data, decode_data = self._prepare_inputs(scheduler_output) + + num_reqs = self.input_batch.num_reqs + + ######################### DECODES ######################### + # Decodes run as one single batch with [padded_batch, 1] + sampled_token_ids_list = [] + if decode_data.num_decodes > 0: + # FORWARD. + with set_forward_context(decode_data.attn_metadata, + self.vllm_config): + assert self.model is not None + selected_token_ids = self.model(decode_data.token_ids, + decode_data.position_ids, + decode_data.attn_metadata, + self.kv_caches) + + # NOTE: TPU<>CPU sync happens here. + # We need to call .cpu() first to avoid recompilation. + token_ids = selected_token_ids.cpu()[:decode_data.num_decodes] + sampled_token_ids_list.extend(token_ids.tolist()) + + # UPDATE REQUEST STATE. + for i, req_id in enumerate( + self.input_batch.req_ids[:decode_data.num_decodes]): + assert req_id is not None + req_state = self.requests[req_id] + + assert scheduler_output.num_scheduled_tokens[req_id] == 1 + seq_len = (req_state.num_computed_tokens + + scheduler_output.num_scheduled_tokens[req_id]) + assert seq_len == req_state.num_tokens + + token_id = sampled_token_ids_list[i] + self.input_batch.token_ids_cpu[i, seq_len] = token_id + self.input_batch.num_tokens[i] += 1 + req_state.output_token_ids.append(token_id) + + ######################### PREFILLS ######################### + # Prefills run separately with shape [1, padded_prefill_len], + # due to lack of variable length attention kernel so far. + for (req_id, prompt_len, token_ids, position_ids, + attn_metadata) in prefill_data.zipped(): + assert req_id is not None + + # FORWARD. + with set_forward_context(attn_metadata, self.vllm_config): + assert self.model is not None + selected_token_ids = self.model(token_ids, position_ids, + attn_metadata, self.kv_caches) + + # NOTE: TPU<>CPU sync happens here. + # We need to call .cpu() first to avoid recompilation. + token_id = selected_token_ids.cpu()[prompt_len - 1].item() + sampled_token_ids_list.append(token_id) + req_state = self.requests[req_id] + + assert req_state.num_computed_tokens == 0 + seq_len = (req_state.num_computed_tokens + + scheduler_output.num_scheduled_tokens[req_id]) + assert seq_len == req_state.num_tokens + assert prompt_len == seq_len + + # UPDATE REQUEST STATE. + req_idx = self.input_batch.req_id_to_index[req_id] + self.input_batch.token_ids_cpu[req_idx, seq_len] = token_id + self.input_batch.num_tokens[req_idx] += 1 + req_state.output_token_ids.append(token_id) + + # num_reqs entries should be non-None + assert all( + req_id is not None for req_id in + self.input_batch.req_ids[:num_reqs]), "req_ids contains None" + req_ids = cast(List[str], self.input_batch.req_ids[:num_reqs]) + + model_runner_output = ModelRunnerOutput( + req_ids=req_ids, + req_id_to_index=self.input_batch.req_id_to_index, + sampled_token_ids=sampled_token_ids_list, + logprob_token_ids_cpu=None, + logprobs_cpu=None, + ) + + return model_runner_output + + def load_model(self) -> None: + self.device = self.device_config.device + + # NOTE(woosuk): While the executor assigns the TP ranks to the worker + # process, the ranks can be different from the ranks internally assigned + # by the xm runtime. Therefore, there is a mismatch in the rank + # assignment between the gloo (cpu) runtime and the xm (tpu) runtime. + # This is not a problem in linear layers because all-reduce is + # rank-agnostic. However, it matters for all-gather as the ranks + # determine the order of concatenating the output tensors. + # As a workaround, we use the xm's rank assignment only when loading + # the embedding weights. + xm_tp_rank = xr.global_ordinal() + with patch( + "vllm.model_executor.layers.vocab_parallel_embedding." + "get_tensor_model_parallel_rank", + return_value=xm_tp_rank): + model = get_model(vllm_config=self.vllm_config) + model = model.eval() + xm.wait_device_ops() + model = ModelWrapperV1(model) + self.model = torch.compile(model, + backend="openxla", + fullgraph=True, + dynamic=False) + + def dummy_run( + self, + kv_caches, + num_tokens: int, + seq_len: Optional[int] = None, + exec_mode: Optional[ExecutionMode] = None, + ) -> None: + assert seq_len is not None + assert exec_mode is not None + + exec_mode = ExecutionMode(exec_mode) + if exec_mode.is_prefill(): + seq_len = (seq_len + 15) // 16 * 16 + token_ids = torch.zeros((num_tokens, seq_len), + dtype=torch.int32, + device=self.device) + position_ids = torch.zeros((num_tokens, seq_len), + dtype=torch.int32, + device=self.device) + slot_mapping = torch.zeros((num_tokens, seq_len), + dtype=torch.int64, + device=self.device) + if exec_mode == ExecutionMode.PREFILL: + attn_metadata = PallasMetadata( + num_prefills=num_tokens, + num_prefill_tokens=num_tokens * seq_len, + num_decode_tokens=0, + slot_mapping=slot_mapping, + multi_modal_placeholder_index_maps=None, + enable_kv_scales_calculation=True, + block_tables=None, + context_lens=None, + effective_query_lens=None, + ) + + else: + context_lens = torch.ones((num_tokens, ), + dtype=torch.int32, + device=self.device) + + block_tables = torch.zeros( + (num_tokens, self.max_num_blocks_per_req), + dtype=torch.int32, + device=self.device) + + effective_query_lens = torch.ones_like(context_lens) + + attn_metadata = PallasMetadata( + num_prefills=num_tokens, + num_prefill_tokens=num_tokens * seq_len, + num_decode_tokens=0, + slot_mapping=slot_mapping, + multi_modal_placeholder_index_maps=None, + enable_kv_scales_calculation=True, + block_tables=block_tables, + context_lens=context_lens, + effective_query_lens=effective_query_lens, + ) + else: + assert seq_len == 1 + token_ids = torch.zeros((num_tokens, seq_len), + dtype=torch.int32, + device=self.device) + position_ids = torch.zeros((num_tokens, seq_len), + dtype=torch.int32, + device=self.device) + slot_mapping = torch.zeros((num_tokens, seq_len), + dtype=torch.int64, + device=self.device) + block_tables = torch.zeros( + (num_tokens, self.max_num_blocks_per_req), + dtype=torch.int32, + device=self.device) + context_lens = torch.ones((num_tokens, ), + dtype=torch.int32, + device=self.device) + attn_metadata = PallasMetadata( + num_prefills=0, + num_prefill_tokens=0, + num_decode_tokens=num_tokens * seq_len, + slot_mapping=slot_mapping, + multi_modal_placeholder_index_maps=None, + enable_kv_scales_calculation=True, + block_tables=block_tables, + context_lens=context_lens, + ) + + # NOTE(woosuk): There are two stages of compilation: torch.compile and + # XLA compilation. Using `mark_dynamic` can reduce the torch.compile + # overhead by reusing the FX graph for different shapes. + # However, the XLA graph will still require static shapes and needs to + # be re-compiled for every different shapes. This overhead is inevitable + # in the first run, but can be skipped afterwards as we cache the XLA + # graphs in the disk (VLLM_XLA_CACHE_PATH). + if exec_mode.is_prefill(): + # Prefll + torch._dynamo.mark_dynamic(token_ids, 1) + torch._dynamo.mark_dynamic(position_ids, 1) + torch._dynamo.mark_dynamic(attn_metadata.slot_mapping, 1) + else: + # Decode + torch._dynamo.mark_dynamic(token_ids, 0) + torch._dynamo.mark_dynamic(position_ids, 0) + torch._dynamo.mark_dynamic(attn_metadata.slot_mapping, 0) + torch._dynamo.mark_dynamic(attn_metadata.context_lens, 0) + torch._dynamo.mark_dynamic(attn_metadata.block_tables, 0) + + # TODO: Remove the attn_metadata above + with set_forward_context(None, self.vllm_config): + assert self.model is not None + self.model(token_ids, position_ids, None, kv_caches) + + def profile_run(self) -> None: + raise NotImplementedError() + + def capture_model(self) -> None: + """Compile the model.""" + + logger.info("Compiling the model with different input shapes.") + + # Capture prefill shapes + start = time.perf_counter() + for batch_size in [1]: + seq_len = 16 + while True: + self.dummy_run(self.kv_caches, batch_size, seq_len, + ExecutionMode.PREFILL) + xm.wait_device_ops() + logger.info(" -- batch_size: %d, seq_len: %d", batch_size, + seq_len) + + if seq_len >= self.model_config.max_model_len: + break + + num_tokens = batch_size * seq_len + if num_tokens >= self.scheduler_config.max_num_batched_tokens: + break + + # Move to next seq_len + seq_len = seq_len * 2 + + end = time.perf_counter() + logger.info("Compilation for prefill shapes is done in %.2f [secs].", + end - start) + + # Capture decode shapes. + start = time.time() + seq_len = 1 + batch_size = 8 # Must be in sync with _get_padded_batch_size() + while True: + self.dummy_run(self.kv_caches, batch_size, seq_len, + ExecutionMode.DECODE) + xm.wait_device_ops() + logger.info(" -- batch_size: %d, seq_len: %d, max_num_seqs = %d", + batch_size, seq_len, + self.scheduler_config.max_num_seqs) + + if batch_size >= self.scheduler_config.max_num_seqs: + break + + batch_size = batch_size + 16 if batch_size >= 16 else batch_size * 2 + + end = time.time() + logger.info("Compilation for decode shapes is done in %.2f [secs].", + end - start) + + def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None: + """ + Initialize KV cache based on `kv_cache_config`. + Args: + kv_cache_config: Configuration for the KV cache, including the KV + cache size of each layer + """ + if len(kv_cache_config.groups) > 1: + raise NotImplementedError( + "Hybrid models with more than one KV cache type are not " + "supported yet.") + + kv_caches: Dict[str, torch.Tensor] = {} + + for layer_name, layer_spec in kv_cache_config.kv_cache_spec.items(): + tensor_config = kv_cache_config.tensors[layer_name] + assert tensor_config.size % layer_spec.page_size_bytes == 0 + num_blocks = tensor_config.size // layer_spec.page_size_bytes + if isinstance(layer_spec, FullAttentionSpec): + kv_cache_shape = PallasAttentionBackend.get_kv_cache_shape( + num_blocks, layer_spec.block_size, layer_spec.num_kv_heads, + layer_spec.head_size) + dtype = layer_spec.dtype + + tpu_k_cache = torch.zeros(kv_cache_shape, + dtype=dtype, + device=self.device) + tpu_v_cache = torch.zeros_like(tpu_k_cache) + + kv_caches[layer_name] = (tpu_k_cache, tpu_v_cache) + else: + raise NotImplementedError + + bind_kv_cache( + kv_caches, + self.vllm_config.compilation_config.static_forward_context, + self.kv_caches) + + +class ModelWrapperV1(nn.Module): + + def __init__(self, model: nn.Module): + super().__init__() + self.model = model + + def forward( + self, + token_ids: torch.Tensor, + position_ids: torch.Tensor, + attn_metadata: AttentionMetadata, + kv_caches: List[Tuple[torch.Tensor, torch.Tensor]], + ) -> torch.Tensor: + """Executes the forward pass of the model and samples the next token. + + Args: + token_ids: The input token IDs of shape [batch_size, seq_len]. + position_ids: The input position IDs of shape [batch_size, seq_len]. + attn_metadata: The Pallas attention metadata. + input_lens: The actual input lengths of shape [batch_size]. + t: The sampling temperature of shape [batch_size]. + p: The top-p probability of shape [batch_size]. + num_samples: Number of samples to draw from each logits vector. + kv_caches: The key and value caches. They can be None during the + memory profiling at initialization. + """ + # Skip this in memory profiling at initialization. + if attn_metadata is not None: + # index_copy_(slot_mapping) only works when the inserted dimension + # is 0. However, the KV cache in the Pallas backend has the shape + # [num_kv_heads, num_blocks, block_size, head_size]. To make it + # work, we need to flatten the first three dimensions and modify + # the slot_mapping accordingly. + num_kv_heads, num_blocks, block_size, _ = kv_caches[0][0].shape + slot_mapping = attn_metadata.slot_mapping + slot_mapping = slot_mapping.flatten() + head_indicies = torch.arange(0, + num_kv_heads, + device=slot_mapping.device, + dtype=slot_mapping.dtype) + head_indicies *= block_size * num_blocks + slot_mapping = slot_mapping.repeat_interleave(num_kv_heads).view( + -1, num_kv_heads) + slot_mapping = slot_mapping + head_indicies.view(1, -1) + slot_mapping = slot_mapping.flatten() + attn_metadata.slot_mapping = slot_mapping + + assert self.model is not None + hidden_states = self.model( + token_ids, + position_ids, + kv_caches, + attn_metadata, + ) + + hidden_states = hidden_states.flatten(0, 1) + logits = self.model.compute_logits(hidden_states, None) + + # Greedy sampling. + argmax_token_ids = torch.argmax(logits, dim=-1, keepdim=True) + argmax_token_ids = argmax_token_ids.squeeze(dim=-1) + return argmax_token_ids + + +def _get_padded_prefill_len(x: int) -> int: + # NOTE(woosuk): The pallas FlashAttention kernel requires the sequence + # length to be a multiple of 16. We pad the prompt length to the nearest + # multiple of 16. This is also good for performance. + if x <= 16: + return 16 + return 1 << (x - 1).bit_length() + + +def _get_padded_batch_size(batch_size: int) -> int: + # The GMM Pallas kernel requires num_tokens * topk to be a multiple of 16. + # To meet this requirement in the simplest way, we set the minimal batch + # size to 8. + if batch_size <= 8: + return 8 + else: + return ((batch_size + 15) // 16) * 16 diff --git a/vllm/v1/worker/tpu_worker.py b/vllm/v1/worker/tpu_worker.py new file mode 100644 index 0000000000000..7d76c82bd7c63 --- /dev/null +++ b/vllm/v1/worker/tpu_worker.py @@ -0,0 +1,141 @@ +"""A TPU worker class.""" +import os +from typing import Optional + +import torch +import torch.distributed +import torch_xla.core.xla_model as xm +import torch_xla.runtime as xr + +import vllm.envs as envs +from vllm.config import ParallelConfig, VllmConfig +from vllm.distributed import (ensure_model_parallel_initialized, + init_distributed_environment) +from vllm.logger import init_logger +from vllm.model_executor import set_random_seed +from vllm.v1.core.scheduler import SchedulerOutput +from vllm.v1.outputs import ModelRunnerOutput +from vllm.v1.worker.tpu_model_runner import ExecutionMode, TPUModelRunner +from vllm.v1.worker.worker_base import WorkerBase + +logger = init_logger(__name__) + + +class TPUWorker(WorkerBase): + + def __init__( + self, + vllm_config: VllmConfig, + local_rank: int, + rank: int, + distributed_init_method: str, + is_driver_worker: bool = False, + ): + super().__init__(vllm_config, local_rank, rank, + distributed_init_method) + + def init_device(self): + os.environ["PJRT_DEVICE"] = "TPU" + torch.set_grad_enabled(False) + torch.set_default_dtype(self.model_config.dtype) + + # Initialize the distributed environment. + init_tpu_worker_distributed_environment(self.parallel_config, + self.rank, + self.distributed_init_method, + self.local_rank) + + # Device initialization should happen after initializing + # the distributed runtime. + self.device = xm.xla_device() + self.device_config.device = self.device + + # Set random seed. + set_random_seed(self.model_config.seed) + xm.set_rng_state(self.model_config.seed, self.device) + + # Increase the cache size limit, which is the maximum number of + # dynamo graphs that can be compiled. + # NOTE(woosuk): Usually, we compile 10-15 graphs for prefill and + # 30-40 graphs for decode. 128 is an arbitrary safe number. + torch._dynamo.config.cache_size_limit = 128 + # Use persistent cache to avoid XLA recompilation. + # NOTE(woosuk): Set per-rank cache path since different ranks + # can have slightly different XLA graphs. + world_size = self.parallel_config.world_size + rank = xr.global_ordinal() + per_rank_path = os.path.join(envs.VLLM_XLA_CACHE_PATH, + f"tp{world_size}_rank{rank}") + xr.initialize_cache(per_rank_path, readonly=False) + + # Init ModelRunner here, so that we have access to self.device. + self.model_runner = TPUModelRunner(self.vllm_config, self.device) + + def determine_available_memory(self) -> int: + assert self.model_runner is not None + + num_layers = self.model_config.get_num_layers(self.parallel_config) + + # use an empty tensor instead of `None`` to force Dynamo to pass + # it by reference, rather by specializing on the value ``None``. + # the `dtype` argument does not matter, and we use `float32` as + # a placeholder (it has wide hardware support). + kv_caches = [(torch.tensor([], dtype=torch.float32, + device=self.device), + torch.tensor([], dtype=torch.float32, + device=self.device)) + for _ in range(num_layers)] + + self.model_runner.dummy_run( + kv_caches, + num_tokens=1, + seq_len=self.scheduler_config.max_num_batched_tokens, + exec_mode=ExecutionMode.PREFILL, + ) + + # Synchronize before measuring the memory usage. + xm.wait_device_ops() + + # Get the maximum amount of memory used by the model weights and + # intermediate activations. + m = xm.get_memory_info(self.device) + total_memory_size = m["bytes_limit"] + profiled = m["peak_bytes_used"] # Weights + intermediate activations. + + # Calculate the TPU KV cache size based on profiling. + usable_memory_size = int(total_memory_size * + self.cache_config.gpu_memory_utilization) + tpu_kv_cache_bytes = max(usable_memory_size - profiled, 0) + + return int(tpu_kv_cache_bytes) + + def execute_model( + self, + scheduler_output: "SchedulerOutput", + ) -> Optional[ModelRunnerOutput]: + assert self.model_runner is not None + output = self.model_runner.execute_model(scheduler_output) + return output if self.rank == 0 else None + + +def init_tpu_worker_distributed_environment( + parallel_config: ParallelConfig, + rank: int, + distributed_init_method: Optional[str] = None, + local_rank: int = -1, +) -> None: + """Initialize the distributed environment.""" + + # NOTE(woosuk): This is just to initialize the TP group and broadcast + # the input objects on CPU. The all-reduce and all-gather ops on TPU + # are invoked by `xm.all_reduce` and `xm.all_gather` which use their + # own context. + init_distributed_environment( + world_size=parallel_config.world_size, + rank=rank, + local_rank=local_rank, + distributed_init_method=distributed_init_method, + backend="gloo", + ) + ensure_model_parallel_initialized(parallel_config.tensor_parallel_size, + parallel_config.pipeline_parallel_size) diff --git a/vllm/v1/worker/worker_base.py b/vllm/v1/worker/worker_base.py new file mode 100644 index 0000000000000..d72f67b3a4d77 --- /dev/null +++ b/vllm/v1/worker/worker_base.py @@ -0,0 +1,173 @@ +"""A GPU worker class.""" +from typing import TYPE_CHECKING, Optional + +import torch +import torch.distributed +import torch.nn as nn + +import vllm.envs as envs +from vllm.config import CacheConfig, ModelConfig, ParallelConfig, VllmConfig +from vllm.logger import init_logger +from vllm.model_executor import set_random_seed +from vllm.platforms import current_platform +from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, LayerBlockType, get_dtype_size +from vllm.v1.core.scheduler import SchedulerOutput +from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec +from vllm.v1.outputs import ModelRunnerOutput +from vllm.v1.worker.model_runner_base import ModelRunnerBase + +logger = init_logger(__name__) + +if TYPE_CHECKING: + from vllm.v1.core.scheduler import SchedulerOutput + + +class WorkerBase: + + def __init__( + self, + vllm_config: VllmConfig, + local_rank: int, + rank: int, + distributed_init_method: str, + is_driver_worker: bool = False, + ): + self.vllm_config = vllm_config + self.model_config = vllm_config.model_config + self.cache_config = vllm_config.cache_config + self.lora_config = vllm_config.lora_config + self.load_config = vllm_config.load_config + self.parallel_config = vllm_config.parallel_config + self.scheduler_config = vllm_config.scheduler_config + self.device_config = vllm_config.device_config + self.speculative_config = vllm_config.speculative_config + self.prompt_adapter_config = vllm_config.prompt_adapter_config + self.observability_config = vllm_config.observability_config + + self.parallel_config.rank = rank + self.local_rank = local_rank + self.rank = rank + self.distributed_init_method = distributed_init_method + + if self.cache_config.cache_dtype == "auto": + self.cache_dtype = self.model_config.dtype + else: + self.cache_dtype = STR_DTYPE_TO_TORCH_DTYPE[ + self.cache_config.cache_dtype] + + if self.model_config.trust_remote_code: + # note: lazy import to avoid importing torch before initializing + from vllm.utils import init_cached_hf_modules + init_cached_hf_modules() + + # Torch profiler. Enabled and configured through env vars: + # VLLM_TORCH_PROFILER_DIR=/path/to/save/trace + if envs.VLLM_TORCH_PROFILER_DIR: + torch_profiler_trace_dir = envs.VLLM_TORCH_PROFILER_DIR + logger.info("Profiling enabled. Traces will be saved to: %s", + torch_profiler_trace_dir) + self.profiler = torch.profiler.profile( + activities=[ + torch.profiler.ProfilerActivity.CPU, + torch.profiler.ProfilerActivity.CUDA, + ], + with_stack=True, + on_trace_ready=torch.profiler.tensorboard_trace_handler( + torch_profiler_trace_dir, use_gzip=True)) + else: + self.profiler = None + + # Initialized by the specific platform + self.model_runner: Optional[ModelRunnerBase] = None + + def init_device(self): + raise NotImplementedError() + + def load_model(self) -> None: + assert self.model_runner is not None + self.model_runner.load_model() + + def determine_available_memory(self) -> int: + raise NotImplementedError() + + def compile_or_warm_up_model(self) -> None: + assert self.model_runner is not None + + if not self.model_config.enforce_eager: + self.model_runner.capture_model() + + # Reset the seed to ensure that the random state is not affected by + # the model initialization and profiling. + set_random_seed(self.model_config.seed) + + def get_model(self) -> nn.Module: + assert self.model_runner is not None + return self.model_runner.get_model() + + def get_kv_cache_spec(self) -> KVCacheSpec: + assert self.model_runner is not None + return self.model_runner.get_kv_cache_spec() + + def initialize_cache(self, kv_cache_config: KVCacheConfig) -> None: + """Allocate GPU KV cache with the specified kv_cache_config.""" + assert self.model_runner is not None + self.model_runner.initialize_kv_cache(kv_cache_config) + + def execute_model( + self, + scheduler_output: "SchedulerOutput", + ) -> Optional[ModelRunnerOutput]: + raise NotImplementedError() + + def profile(self, is_start: bool = True): + if self.profiler is None: + raise RuntimeError("Profiler is not enabled.") + if is_start: + self.profiler.start() + else: + self.profiler.stop() + + def check_health(self) -> None: + # worker will always be healthy as long as it's running. + return + + +def check_if_gpu_supports_dtype(torch_dtype: torch.dtype): + # Check if the GPU supports the dtype. + if torch_dtype == torch.bfloat16: # noqa: SIM102 + if not current_platform.has_device_capability(80): + capability = current_platform.get_device_capability() + gpu_name = current_platform.get_device_name() + + if capability is None: + compute_str = "does not have a compute capability" + else: + version_str = capability.as_version_str() + compute_str = f"has compute capability {version_str}" + + raise ValueError( + "Bfloat16 is only supported on GPUs with compute capability " + f"of at least 8.0. Your {gpu_name} GPU {compute_str}. " + "You can use float16 instead by explicitly setting the" + "`dtype` flag in CLI, for example: --dtype=half.") + + +def get_cache_block_size( + cache_config: CacheConfig, + model_config: ModelConfig, + parallel_config: ParallelConfig, +) -> int: + head_size = model_config.get_head_size() + num_heads = model_config.get_num_kv_heads(parallel_config) + num_attention_layers = model_config.get_num_layers_by_block_type( + parallel_config, LayerBlockType.attention) + + key_cache_block = cache_config.block_size * num_heads * head_size + value_cache_block = key_cache_block + total = num_attention_layers * (key_cache_block + value_cache_block) + if cache_config.cache_dtype == "auto": + dtype = model_config.dtype + else: + dtype = STR_DTYPE_TO_TORCH_DTYPE[cache_config.cache_dtype] + dtype_size = get_dtype_size(dtype) + return dtype_size * total