fix capture model

This commit is contained in:
Alexander Matveev 2025-02-06 20:18:30 +00:00
parent 996b92ccb4
commit 5fb9dbe6f6
3 changed files with 84 additions and 47 deletions

View File

@ -13,13 +13,11 @@ ray[default]
# Install torch_xla # Install torch_xla
--pre --pre
--extra-index-url https://download.pytorch.org/whl/nightly/cpu --extra-index-url https://download.pytorch.org/whl/nightly/cpu
--find-links https://storage.googleapis.com/libtpu-wheels/index.html
--find-links https://storage.googleapis.com/libtpu-releases/index.html --find-links https://storage.googleapis.com/libtpu-releases/index.html
--find-links https://storage.googleapis.com/jax-releases/jax_nightly_releases.html --find-links https://storage.googleapis.com/jax-releases/jax_nightly_releases.html
--find-links https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html --find-links https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html
torch==2.6.0.dev20241126+cpu torch==2.6.0.dev20241216+cpu
torchvision==0.20.0.dev20241126+cpu torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.7.0.dev20250124-cp39-cp39-linux_x86_64.whl ; python_version == "3.9"
torch_xla[tpu] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.6.0.dev20241126-cp39-cp39-linux_x86_64.whl ; python_version == "3.9" torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.7.0.dev20250124-cp310-cp310-linux_x86_64.whl ; python_version == "3.10"
torch_xla[tpu] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.6.0.dev20241126-cp310-cp310-linux_x86_64.whl ; python_version == "3.10" torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.7.0.dev20250124-cp311-cp311-linux_x86_64.whl ; python_version == "3.11"
torch_xla[tpu] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.6.0.dev20241126-cp311-cp311-linux_x86_64.whl ; python_version == "3.11"
jaxlib==0.4.36.dev20241122
jax==0.4.36.dev20241122

View File

@ -24,6 +24,7 @@ from vllm.v1.utils import bind_kv_cache
from vllm.v1.worker.gpu_input_batch import (CachedRequestState, InputBatch, from vllm.v1.worker.gpu_input_batch import (CachedRequestState, InputBatch,
ensure_decodes_first) ensure_decodes_first)
from vllm.v1.worker.model_runner_base import ExecutionMode, ModelRunnerBase from vllm.v1.worker.model_runner_base import ExecutionMode, ModelRunnerBase
from vllm.v1.core.kv_cache_utils import get_kv_cache_config
if TYPE_CHECKING: if TYPE_CHECKING:
from vllm.v1.core.scheduler import SchedulerOutput from vllm.v1.core.scheduler import SchedulerOutput
@ -688,62 +689,88 @@ class TPUModelRunner(ModelRunnerBase):
torch._dynamo.mark_dynamic(attn_metadata.context_lens, 0) torch._dynamo.mark_dynamic(attn_metadata.context_lens, 0)
torch._dynamo.mark_dynamic(attn_metadata.block_tables, 0) torch._dynamo.mark_dynamic(attn_metadata.block_tables, 0)
# TODO: Remove the attn_metadata above with set_forward_context(attn_metadata, self.vllm_config, 0):
with set_forward_context(None, self.vllm_config):
assert self.model is not None assert self.model is not None
self.model(token_ids, position_ids, None, kv_caches) self.model(token_ids, position_ids, attn_metadata, kv_caches)
def capture_model(self) -> None: def capture_model(self) -> None:
"""Compile the model.""" """Compile the model."""
logger.info("Compiling the model with different input shapes.") # Prefill
logger.info(
# Capture prefill shapes "Compiling the model with different input shapes for prefill:")
start = time.perf_counter() start = time.time()
for batch_size in [1]: for batch_size in [1]:
seq_len = 16 seq_len = 16
while True: while seq_len <= self.model_config.max_model_len:
self.dummy_run(self.kv_caches, batch_size, seq_len, self.dummy_run(self.kv_caches,
ExecutionMode.PREFILL) batch_size,
seq_len,
exec_mode=ExecutionMode.PREFILL)
xm.wait_device_ops() xm.wait_device_ops()
logger.info(" -- batch_size: %d, seq_len: %d", batch_size, logger.info(" batch_size: %d, seq_len: %d", batch_size,
seq_len) seq_len)
if seq_len >= self.model_config.max_model_len:
break
num_tokens = batch_size * seq_len num_tokens = batch_size * seq_len
if num_tokens >= self.scheduler_config.max_num_batched_tokens: if num_tokens >= self.scheduler_config.max_num_batched_tokens:
break break
# Move to next seq_len
seq_len = seq_len * 2 seq_len = seq_len * 2
end = time.perf_counter() end = time.time()
logger.info("Compilation for prefill shapes is done in %.2f [secs].", logger.info(" -- Compilation for prefill done in %.2f [secs].",
end - start) end - start)
# Capture decode shapes. # Prefix prefill
if self.scheduler_config.enable_chunked_prefill:
logger.info("Compiling the model with different input shapes for "
"prefix prefill:")
start = time.time()
for batch_size in [1]:
seq_len = 16
while seq_len <= self.model_config.max_model_len:
self.dummy_run(self.kv_caches,
batch_size,
seq_len,
exec_mode=ExecutionMode.PREFIX_PREFILL)
xm.wait_device_ops()
logger.info(" batch_size: %d, seq_len: %d", batch_size,
seq_len)
num_tokens = batch_size * seq_len
if (num_tokens
>= self.scheduler_config.max_num_batched_tokens):
break
seq_len = seq_len * 2
end = time.time()
logger.info(
" -- Compilation for prefix prefill done in %.2f [secs].",
end - start)
# Decode
logger.info(
"Compiling the model with different input shapes for decode:")
start = time.time() start = time.time()
seq_len = 1 seq_len = 1
batch_size = 8 # Must be in sync with _get_padded_batch_size() batch_size = 8 # Must be in sync with _get_padded_batch_size()
while True: while True:
self.dummy_run(self.kv_caches, batch_size, seq_len, self.dummy_run(self.kv_caches,
ExecutionMode.DECODE) batch_size,
seq_len,
exec_mode=ExecutionMode.DECODE)
xm.wait_device_ops() xm.wait_device_ops()
logger.info(" -- batch_size: %d, seq_len: %d, max_num_seqs = %d", logger.info(" batch_size: %d, seq_len: %d", batch_size, seq_len)
batch_size, seq_len,
self.scheduler_config.max_num_seqs)
if batch_size >= self.scheduler_config.max_num_seqs: if batch_size >= self.scheduler_config.max_num_seqs:
break break
batch_size = batch_size + 16 if batch_size >= 16 else batch_size * 2 batch_size = batch_size + 16 if batch_size >= 16 else batch_size * 2
end = time.time() end = time.time()
logger.info("Compilation for decode shapes is done in %.2f [secs].", logger.info(" -- Compilation for decode done in %.2f [secs].",
end - start) end - start)
def _initialize_kv_cache(self):
kv_cache_spec = self.get_kv_cache_spec()
kv_cache_config = get_kv_cache_config(vllm_config, kv_cache_spec,
availble_gpu_memory)
def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None: def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None:
""" """
Initialize KV cache based on `kv_cache_config`. Initialize KV cache based on `kv_cache_config`.
@ -810,7 +837,7 @@ class ModelWrapperV1(nn.Module):
memory profiling at initialization. memory profiling at initialization.
""" """
# Skip this in memory profiling at initialization. # Skip this in memory profiling at initialization.
if attn_metadata is not None: if attn_metadata is not None and kv_caches[0][0].numel() > 0:
# index_copy_(slot_mapping) only works when the inserted dimension # index_copy_(slot_mapping) only works when the inserted dimension
# is 0. However, the KV cache in the Pallas backend has the shape # is 0. However, the KV cache in the Pallas backend has the shape
# [num_kv_heads, num_blocks, block_size, head_size]. To make it # [num_kv_heads, num_blocks, block_size, head_size]. To make it

View File

@ -1,6 +1,6 @@
"""A TPU worker class.""" """A TPU worker class."""
import os import os
from typing import Optional from typing import Optional, Dict
import torch import torch
import torch.distributed import torch.distributed
@ -13,10 +13,13 @@ from vllm.distributed import (ensure_model_parallel_initialized,
init_distributed_environment) init_distributed_environment)
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor import set_random_seed from vllm.model_executor import set_random_seed
from vllm.v1.kv_cache_interface import FullAttentionSpec
from vllm.v1.attention.backends.pallas import PallasAttentionBackend
from vllm.v1.core.scheduler import SchedulerOutput from vllm.v1.core.scheduler import SchedulerOutput
from vllm.v1.outputs import ModelRunnerOutput from vllm.v1.outputs import ModelRunnerOutput
from vllm.v1.worker.tpu_model_runner import ExecutionMode, TPUModelRunner from vllm.v1.worker.tpu_model_runner import ExecutionMode, TPUModelRunner
from vllm.v1.worker.worker_base import WorkerBase from vllm.v1.worker.worker_base import WorkerBase
from vllm.v1.utils import bind_kv_cache
logger = init_logger(__name__) logger = init_logger(__name__)
@ -74,20 +77,29 @@ class TPUWorker(WorkerBase):
def determine_available_memory(self) -> int: def determine_available_memory(self) -> int:
assert self.model_runner is not None assert self.model_runner is not None
num_layers = self.model_config.get_num_layers(self.parallel_config) kv_caches: Dict[str, torch.Tensor] = {}
kv_cache_spec = self.model_runner.get_kv_cache_spec()
for layer_name, layer_spec in kv_cache_spec.items():
if isinstance(layer_spec, FullAttentionSpec):
dtype = layer_spec.dtype
# use an empty tensor instead of `None`` to force Dynamo to pass # Use an empty tensor instead of `None`` to force Dynamo to pass
# it by reference, rather by specializing on the value ``None``. # it by reference, rather by specializing on the value ``None``.
# the `dtype` argument does not matter, and we use `float32` as tpu_k_cache = torch.tensor([], dtype=dtype, device=self.device)
# a placeholder (it has wide hardware support). tpu_v_cache = torch.tensor([], dtype=dtype, device=self.device)
kv_caches = [(torch.tensor([], dtype=torch.float32,
device=self.device), kv_caches[layer_name] = (tpu_k_cache, tpu_v_cache)
torch.tensor([], dtype=torch.float32, else:
device=self.device)) raise NotImplementedError
for _ in range(num_layers)]
runner_kv_caches = []
bind_kv_cache(
kv_caches,
self.vllm_config.compilation_config.static_forward_context,
runner_kv_caches)
self.model_runner.dummy_run( self.model_runner.dummy_run(
kv_caches, runner_kv_caches,
num_tokens=1, num_tokens=1,
seq_len=self.scheduler_config.max_num_batched_tokens, seq_len=self.scheduler_config.max_num_batched_tokens,
exec_mode=ExecutionMode.PREFILL, exec_mode=ExecutionMode.PREFILL,