fix capture model

This commit is contained in:
Alexander Matveev 2025-02-06 20:18:30 +00:00
parent 996b92ccb4
commit 5fb9dbe6f6
3 changed files with 84 additions and 47 deletions

View File

@ -13,13 +13,11 @@ ray[default]
# Install torch_xla
--pre
--extra-index-url https://download.pytorch.org/whl/nightly/cpu
--find-links https://storage.googleapis.com/libtpu-wheels/index.html
--find-links https://storage.googleapis.com/libtpu-releases/index.html
--find-links https://storage.googleapis.com/jax-releases/jax_nightly_releases.html
--find-links https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html
torch==2.6.0.dev20241126+cpu
torchvision==0.20.0.dev20241126+cpu
torch_xla[tpu] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.6.0.dev20241126-cp39-cp39-linux_x86_64.whl ; python_version == "3.9"
torch_xla[tpu] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.6.0.dev20241126-cp310-cp310-linux_x86_64.whl ; python_version == "3.10"
torch_xla[tpu] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.6.0.dev20241126-cp311-cp311-linux_x86_64.whl ; python_version == "3.11"
jaxlib==0.4.36.dev20241122
jax==0.4.36.dev20241122
torch==2.6.0.dev20241216+cpu
torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.7.0.dev20250124-cp39-cp39-linux_x86_64.whl ; python_version == "3.9"
torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.7.0.dev20250124-cp310-cp310-linux_x86_64.whl ; python_version == "3.10"
torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.7.0.dev20250124-cp311-cp311-linux_x86_64.whl ; python_version == "3.11"

View File

@ -24,6 +24,7 @@ from vllm.v1.utils import bind_kv_cache
from vllm.v1.worker.gpu_input_batch import (CachedRequestState, InputBatch,
ensure_decodes_first)
from vllm.v1.worker.model_runner_base import ExecutionMode, ModelRunnerBase
from vllm.v1.core.kv_cache_utils import get_kv_cache_config
if TYPE_CHECKING:
from vllm.v1.core.scheduler import SchedulerOutput
@ -688,62 +689,88 @@ class TPUModelRunner(ModelRunnerBase):
torch._dynamo.mark_dynamic(attn_metadata.context_lens, 0)
torch._dynamo.mark_dynamic(attn_metadata.block_tables, 0)
# TODO: Remove the attn_metadata above
with set_forward_context(None, self.vllm_config):
with set_forward_context(attn_metadata, self.vllm_config, 0):
assert self.model is not None
self.model(token_ids, position_ids, None, kv_caches)
self.model(token_ids, position_ids, attn_metadata, kv_caches)
def capture_model(self) -> None:
"""Compile the model."""
logger.info("Compiling the model with different input shapes.")
# Capture prefill shapes
start = time.perf_counter()
# Prefill
logger.info(
"Compiling the model with different input shapes for prefill:")
start = time.time()
for batch_size in [1]:
seq_len = 16
while True:
self.dummy_run(self.kv_caches, batch_size, seq_len,
ExecutionMode.PREFILL)
while seq_len <= self.model_config.max_model_len:
self.dummy_run(self.kv_caches,
batch_size,
seq_len,
exec_mode=ExecutionMode.PREFILL)
xm.wait_device_ops()
logger.info(" -- batch_size: %d, seq_len: %d", batch_size,
logger.info(" batch_size: %d, seq_len: %d", batch_size,
seq_len)
if seq_len >= self.model_config.max_model_len:
break
num_tokens = batch_size * seq_len
if num_tokens >= self.scheduler_config.max_num_batched_tokens:
break
# Move to next seq_len
seq_len = seq_len * 2
end = time.perf_counter()
logger.info("Compilation for prefill shapes is done in %.2f [secs].",
end = time.time()
logger.info(" -- Compilation for prefill done in %.2f [secs].",
end - start)
# Capture decode shapes.
# Prefix prefill
if self.scheduler_config.enable_chunked_prefill:
logger.info("Compiling the model with different input shapes for "
"prefix prefill:")
start = time.time()
for batch_size in [1]:
seq_len = 16
while seq_len <= self.model_config.max_model_len:
self.dummy_run(self.kv_caches,
batch_size,
seq_len,
exec_mode=ExecutionMode.PREFIX_PREFILL)
xm.wait_device_ops()
logger.info(" batch_size: %d, seq_len: %d", batch_size,
seq_len)
num_tokens = batch_size * seq_len
if (num_tokens
>= self.scheduler_config.max_num_batched_tokens):
break
seq_len = seq_len * 2
end = time.time()
logger.info(
" -- Compilation for prefix prefill done in %.2f [secs].",
end - start)
# Decode
logger.info(
"Compiling the model with different input shapes for decode:")
start = time.time()
seq_len = 1
batch_size = 8 # Must be in sync with _get_padded_batch_size()
while True:
self.dummy_run(self.kv_caches, batch_size, seq_len,
ExecutionMode.DECODE)
self.dummy_run(self.kv_caches,
batch_size,
seq_len,
exec_mode=ExecutionMode.DECODE)
xm.wait_device_ops()
logger.info(" -- batch_size: %d, seq_len: %d, max_num_seqs = %d",
batch_size, seq_len,
self.scheduler_config.max_num_seqs)
logger.info(" batch_size: %d, seq_len: %d", batch_size, seq_len)
if batch_size >= self.scheduler_config.max_num_seqs:
break
batch_size = batch_size + 16 if batch_size >= 16 else batch_size * 2
end = time.time()
logger.info("Compilation for decode shapes is done in %.2f [secs].",
logger.info(" -- Compilation for decode done in %.2f [secs].",
end - start)
def _initialize_kv_cache(self):
kv_cache_spec = self.get_kv_cache_spec()
kv_cache_config = get_kv_cache_config(vllm_config, kv_cache_spec,
availble_gpu_memory)
def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None:
"""
Initialize KV cache based on `kv_cache_config`.
@ -810,7 +837,7 @@ class ModelWrapperV1(nn.Module):
memory profiling at initialization.
"""
# Skip this in memory profiling at initialization.
if attn_metadata is not None:
if attn_metadata is not None and kv_caches[0][0].numel() > 0:
# index_copy_(slot_mapping) only works when the inserted dimension
# is 0. However, the KV cache in the Pallas backend has the shape
# [num_kv_heads, num_blocks, block_size, head_size]. To make it

View File

@ -1,6 +1,6 @@
"""A TPU worker class."""
import os
from typing import Optional
from typing import Optional, Dict
import torch
import torch.distributed
@ -13,10 +13,13 @@ from vllm.distributed import (ensure_model_parallel_initialized,
init_distributed_environment)
from vllm.logger import init_logger
from vllm.model_executor import set_random_seed
from vllm.v1.kv_cache_interface import FullAttentionSpec
from vllm.v1.attention.backends.pallas import PallasAttentionBackend
from vllm.v1.core.scheduler import SchedulerOutput
from vllm.v1.outputs import ModelRunnerOutput
from vllm.v1.worker.tpu_model_runner import ExecutionMode, TPUModelRunner
from vllm.v1.worker.worker_base import WorkerBase
from vllm.v1.utils import bind_kv_cache
logger = init_logger(__name__)
@ -74,20 +77,29 @@ class TPUWorker(WorkerBase):
def determine_available_memory(self) -> int:
assert self.model_runner is not None
num_layers = self.model_config.get_num_layers(self.parallel_config)
kv_caches: Dict[str, torch.Tensor] = {}
kv_cache_spec = self.model_runner.get_kv_cache_spec()
for layer_name, layer_spec in kv_cache_spec.items():
if isinstance(layer_spec, FullAttentionSpec):
dtype = layer_spec.dtype
# Use an empty tensor instead of ``None`` to force Dynamo to pass
# it by reference, rather than by specializing on the value ``None``.
# The `dtype` argument does not matter, and we use `float32` as
# a placeholder (it has wide hardware support).
kv_caches = [(torch.tensor([], dtype=torch.float32,
device=self.device),
torch.tensor([], dtype=torch.float32,
device=self.device))
for _ in range(num_layers)]
# Use an empty tensor instead of ``None`` to force Dynamo to pass
# it by reference, rather than by specializing on the value ``None``.
tpu_k_cache = torch.tensor([], dtype=dtype, device=self.device)
tpu_v_cache = torch.tensor([], dtype=dtype, device=self.device)
kv_caches[layer_name] = (tpu_k_cache, tpu_v_cache)
else:
raise NotImplementedError
runner_kv_caches = []
bind_kv_cache(
kv_caches,
self.vllm_config.compilation_config.static_forward_context,
runner_kv_caches)
self.model_runner.dummy_run(
kv_caches,
runner_kv_caches,
num_tokens=1,
seq_len=self.scheduler_config.max_num_batched_tokens,
exec_mode=ExecutionMode.PREFILL,