From 5fb9dbe6f635ac05d939726f83bdcfba2c036253 Mon Sep 17 00:00:00 2001 From: Alexander Matveev Date: Thu, 6 Feb 2025 20:18:30 +0000 Subject: [PATCH] fix capture model --- requirements-tpu.txt | 12 ++--- vllm/v1/worker/tpu_model_runner.py | 83 ++++++++++++++++++++---------- vllm/v1/worker/tpu_worker.py | 36 ++++++++----- 3 files changed, 84 insertions(+), 47 deletions(-) diff --git a/requirements-tpu.txt b/requirements-tpu.txt index 8ab18b3770ae8..f438b5f4c8f68 100644 --- a/requirements-tpu.txt +++ b/requirements-tpu.txt @@ -13,13 +13,11 @@ ray[default] # Install torch_xla --pre --extra-index-url https://download.pytorch.org/whl/nightly/cpu +--find-links https://storage.googleapis.com/libtpu-wheels/index.html --find-links https://storage.googleapis.com/libtpu-releases/index.html --find-links https://storage.googleapis.com/jax-releases/jax_nightly_releases.html --find-links https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html -torch==2.6.0.dev20241126+cpu -torchvision==0.20.0.dev20241126+cpu -torch_xla[tpu] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.6.0.dev20241126-cp39-cp39-linux_x86_64.whl ; python_version == "3.9" -torch_xla[tpu] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.6.0.dev20241126-cp310-cp310-linux_x86_64.whl ; python_version == "3.10" -torch_xla[tpu] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.6.0.dev20241126-cp311-cp311-linux_x86_64.whl ; python_version == "3.11" -jaxlib==0.4.36.dev20241122 -jax==0.4.36.dev20241122 +torch==2.6.0.dev20241216+cpu +torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.7.0.dev20250124-cp39-cp39-linux_x86_64.whl ; python_version == "3.9" +torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.7.0.dev20250124-cp310-cp310-linux_x86_64.whl ; python_version == "3.10" +torch_xla[tpu, pallas] @ 
https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.7.0.dev20250124-cp311-cp311-linux_x86_64.whl ; python_version == "3.11" \ No newline at end of file diff --git a/vllm/v1/worker/tpu_model_runner.py b/vllm/v1/worker/tpu_model_runner.py index 31301ff0e21a7..568e2f0cef4d8 100644 --- a/vllm/v1/worker/tpu_model_runner.py +++ b/vllm/v1/worker/tpu_model_runner.py @@ -24,6 +24,7 @@ from vllm.v1.utils import bind_kv_cache from vllm.v1.worker.gpu_input_batch import (CachedRequestState, InputBatch, ensure_decodes_first) from vllm.v1.worker.model_runner_base import ExecutionMode, ModelRunnerBase +from vllm.v1.core.kv_cache_utils import get_kv_cache_config if TYPE_CHECKING: from vllm.v1.core.scheduler import SchedulerOutput @@ -688,62 +689,88 @@ class TPUModelRunner(ModelRunnerBase): torch._dynamo.mark_dynamic(attn_metadata.context_lens, 0) torch._dynamo.mark_dynamic(attn_metadata.block_tables, 0) - # TODO: Remove the attn_metadata above - with set_forward_context(None, self.vllm_config): + with set_forward_context(attn_metadata, self.vllm_config, 0): assert self.model is not None - self.model(token_ids, position_ids, None, kv_caches) + self.model(token_ids, position_ids, attn_metadata, kv_caches) def capture_model(self) -> None: """Compile the model.""" - logger.info("Compiling the model with different input shapes.") - - # Capture prefill shapes - start = time.perf_counter() + # Prefill + logger.info( + "Compiling the model with different input shapes for prefill:") + start = time.time() for batch_size in [1]: seq_len = 16 - while True: - self.dummy_run(self.kv_caches, batch_size, seq_len, - ExecutionMode.PREFILL) + while seq_len <= self.model_config.max_model_len: + self.dummy_run(self.kv_caches, + batch_size, + seq_len, + exec_mode=ExecutionMode.PREFILL) xm.wait_device_ops() - logger.info(" -- batch_size: %d, seq_len: %d", batch_size, + logger.info(" batch_size: %d, seq_len: %d", batch_size, seq_len) - - if seq_len >= 
self.model_config.max_model_len: - break - num_tokens = batch_size * seq_len if num_tokens >= self.scheduler_config.max_num_batched_tokens: break - - # Move to next seq_len seq_len = seq_len * 2 - end = time.perf_counter() - logger.info("Compilation for prefill shapes is done in %.2f [secs].", + end = time.time() + logger.info(" -- Compilation for prefill done in %.2f [secs].", end - start) - # Capture decode shapes. + # Prefix prefill + if self.scheduler_config.enable_chunked_prefill: + logger.info("Compiling the model with different input shapes for " + "prefix prefill:") + start = time.time() + for batch_size in [1]: + seq_len = 16 + while seq_len <= self.model_config.max_model_len: + self.dummy_run(self.kv_caches, + batch_size, + seq_len, + exec_mode=ExecutionMode.PREFIX_PREFILL) + xm.wait_device_ops() + logger.info(" batch_size: %d, seq_len: %d", batch_size, + seq_len) + num_tokens = batch_size * seq_len + if (num_tokens + >= self.scheduler_config.max_num_batched_tokens): + break + seq_len = seq_len * 2 + end = time.time() + logger.info( + " -- Compilation for prefix prefill done in %.2f [secs].", + end - start) + + # Decode + logger.info( + "Compiling the model with different input shapes for decode:") start = time.time() seq_len = 1 batch_size = 8 # Must be in sync with _get_padded_batch_size() while True: - self.dummy_run(self.kv_caches, batch_size, seq_len, - ExecutionMode.DECODE) + self.dummy_run(self.kv_caches, + batch_size, + seq_len, + exec_mode=ExecutionMode.DECODE) xm.wait_device_ops() - logger.info(" -- batch_size: %d, seq_len: %d, max_num_seqs = %d", - batch_size, seq_len, - self.scheduler_config.max_num_seqs) + logger.info(" batch_size: %d, seq_len: %d", batch_size, seq_len) if batch_size >= self.scheduler_config.max_num_seqs: break - batch_size = batch_size + 16 if batch_size >= 16 else batch_size * 2 end = time.time() - logger.info("Compilation for decode shapes is done in %.2f [secs].", + logger.info(" -- Compilation for decode done in %.2f 
[secs].", end - start) + def _initialize_kv_cache(self): + kv_cache_spec = self.get_kv_cache_spec() + + kv_cache_config = get_kv_cache_config(vllm_config, kv_cache_spec, + availble_gpu_memory) def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None: """ Initialize KV cache based on `kv_cache_config`. @@ -810,7 +837,7 @@ class ModelWrapperV1(nn.Module): memory profiling at initialization. """ # Skip this in memory profiling at initialization. - if attn_metadata is not None: + if attn_metadata is not None and kv_caches[0][0].numel() > 0: # index_copy_(slot_mapping) only works when the inserted dimension # is 0. However, the KV cache in the Pallas backend has the shape # [num_kv_heads, num_blocks, block_size, head_size]. To make it diff --git a/vllm/v1/worker/tpu_worker.py b/vllm/v1/worker/tpu_worker.py index 7d76c82bd7c63..d277c75a2daab 100644 --- a/vllm/v1/worker/tpu_worker.py +++ b/vllm/v1/worker/tpu_worker.py @@ -1,6 +1,6 @@ """A TPU worker class.""" import os -from typing import Optional +from typing import Optional, Dict import torch import torch.distributed @@ -13,10 +13,13 @@ from vllm.distributed import (ensure_model_parallel_initialized, init_distributed_environment) from vllm.logger import init_logger from vllm.model_executor import set_random_seed +from vllm.v1.kv_cache_interface import FullAttentionSpec +from vllm.v1.attention.backends.pallas import PallasAttentionBackend from vllm.v1.core.scheduler import SchedulerOutput from vllm.v1.outputs import ModelRunnerOutput from vllm.v1.worker.tpu_model_runner import ExecutionMode, TPUModelRunner from vllm.v1.worker.worker_base import WorkerBase +from vllm.v1.utils import bind_kv_cache logger = init_logger(__name__) @@ -74,20 +77,29 @@ class TPUWorker(WorkerBase): def determine_available_memory(self) -> int: assert self.model_runner is not None - num_layers = self.model_config.get_num_layers(self.parallel_config) + kv_caches: Dict[str, torch.Tensor] = {} + kv_cache_spec = 
self.model_runner.get_kv_cache_spec() + for layer_name, layer_spec in kv_cache_spec.items(): + if isinstance(layer_spec, FullAttentionSpec): + dtype = layer_spec.dtype - # use an empty tensor instead of `None`` to force Dynamo to pass - # it by reference, rather by specializing on the value ``None``. - # the `dtype` argument does not matter, and we use `float32` as - # a placeholder (it has wide hardware support). - kv_caches = [(torch.tensor([], dtype=torch.float32, - device=self.device), - torch.tensor([], dtype=torch.float32, - device=self.device)) - for _ in range(num_layers)] + # Use an empty tensor instead of ``None`` to force Dynamo to pass + # it by reference, rather than by specializing on the value ``None``. + tpu_k_cache = torch.tensor([], dtype=dtype, device=self.device) + tpu_v_cache = torch.tensor([], dtype=dtype, device=self.device) + + kv_caches[layer_name] = (tpu_k_cache, tpu_v_cache) + else: + raise NotImplementedError + + runner_kv_caches = [] + bind_kv_cache( + kv_caches, + self.vllm_config.compilation_config.static_forward_context, + runner_kv_caches) self.model_runner.dummy_run( - kv_caches, + runner_kv_caches, num_tokens=1, seq_len=self.scheduler_config.max_num_batched_tokens, exec_mode=ExecutionMode.PREFILL,