mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-03-16 13:27:19 +08:00
fix capture model
This commit is contained in:
parent
996b92ccb4
commit
5fb9dbe6f6
@ -13,13 +13,11 @@ ray[default]
|
||||
# Install torch_xla
|
||||
--pre
|
||||
--extra-index-url https://download.pytorch.org/whl/nightly/cpu
|
||||
--find-links https://storage.googleapis.com/libtpu-wheels/index.html
|
||||
--find-links https://storage.googleapis.com/libtpu-releases/index.html
|
||||
--find-links https://storage.googleapis.com/jax-releases/jax_nightly_releases.html
|
||||
--find-links https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html
|
||||
torch==2.6.0.dev20241126+cpu
|
||||
torchvision==0.20.0.dev20241126+cpu
|
||||
torch_xla[tpu] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.6.0.dev20241126-cp39-cp39-linux_x86_64.whl ; python_version == "3.9"
|
||||
torch_xla[tpu] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.6.0.dev20241126-cp310-cp310-linux_x86_64.whl ; python_version == "3.10"
|
||||
torch_xla[tpu] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.6.0.dev20241126-cp311-cp311-linux_x86_64.whl ; python_version == "3.11"
|
||||
jaxlib==0.4.36.dev20241122
|
||||
jax==0.4.36.dev20241122
|
||||
torch==2.6.0.dev20241216+cpu
|
||||
torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.7.0.dev20250124-cp39-cp39-linux_x86_64.whl ; python_version == "3.9"
|
||||
torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.7.0.dev20250124-cp310-cp310-linux_x86_64.whl ; python_version == "3.10"
|
||||
torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.7.0.dev20250124-cp311-cp311-linux_x86_64.whl ; python_version == "3.11"
|
||||
@ -24,6 +24,7 @@ from vllm.v1.utils import bind_kv_cache
|
||||
from vllm.v1.worker.gpu_input_batch import (CachedRequestState, InputBatch,
|
||||
ensure_decodes_first)
|
||||
from vllm.v1.worker.model_runner_base import ExecutionMode, ModelRunnerBase
|
||||
from vllm.v1.core.kv_cache_utils import get_kv_cache_config
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.v1.core.scheduler import SchedulerOutput
|
||||
@ -688,62 +689,88 @@ class TPUModelRunner(ModelRunnerBase):
|
||||
torch._dynamo.mark_dynamic(attn_metadata.context_lens, 0)
|
||||
torch._dynamo.mark_dynamic(attn_metadata.block_tables, 0)
|
||||
|
||||
# TODO: Remove the attn_metadata above
|
||||
with set_forward_context(None, self.vllm_config):
|
||||
with set_forward_context(attn_metadata, self.vllm_config, 0):
|
||||
assert self.model is not None
|
||||
self.model(token_ids, position_ids, None, kv_caches)
|
||||
self.model(token_ids, position_ids, attn_metadata, kv_caches)
|
||||
|
||||
def capture_model(self) -> None:
|
||||
"""Compile the model."""
|
||||
|
||||
logger.info("Compiling the model with different input shapes.")
|
||||
|
||||
# Capture prefill shapes
|
||||
start = time.perf_counter()
|
||||
# Prefill
|
||||
logger.info(
|
||||
"Compiling the model with different input shapes for prefill:")
|
||||
start = time.time()
|
||||
for batch_size in [1]:
|
||||
seq_len = 16
|
||||
while True:
|
||||
self.dummy_run(self.kv_caches, batch_size, seq_len,
|
||||
ExecutionMode.PREFILL)
|
||||
while seq_len <= self.model_config.max_model_len:
|
||||
self.dummy_run(self.kv_caches,
|
||||
batch_size,
|
||||
seq_len,
|
||||
exec_mode=ExecutionMode.PREFILL)
|
||||
xm.wait_device_ops()
|
||||
logger.info(" -- batch_size: %d, seq_len: %d", batch_size,
|
||||
logger.info(" batch_size: %d, seq_len: %d", batch_size,
|
||||
seq_len)
|
||||
|
||||
if seq_len >= self.model_config.max_model_len:
|
||||
break
|
||||
|
||||
num_tokens = batch_size * seq_len
|
||||
if num_tokens >= self.scheduler_config.max_num_batched_tokens:
|
||||
break
|
||||
|
||||
# Move to next seq_len
|
||||
seq_len = seq_len * 2
|
||||
|
||||
end = time.perf_counter()
|
||||
logger.info("Compilation for prefill shapes is done in %.2f [secs].",
|
||||
end = time.time()
|
||||
logger.info(" -- Compilation for prefill done in %.2f [secs].",
|
||||
end - start)
|
||||
|
||||
# Capture decode shapes.
|
||||
# Prefix prefill
|
||||
if self.scheduler_config.enable_chunked_prefill:
|
||||
logger.info("Compiling the model with different input shapes for "
|
||||
"prefix prefill:")
|
||||
start = time.time()
|
||||
for batch_size in [1]:
|
||||
seq_len = 16
|
||||
while seq_len <= self.model_config.max_model_len:
|
||||
self.dummy_run(self.kv_caches,
|
||||
batch_size,
|
||||
seq_len,
|
||||
exec_mode=ExecutionMode.PREFIX_PREFILL)
|
||||
xm.wait_device_ops()
|
||||
logger.info(" batch_size: %d, seq_len: %d", batch_size,
|
||||
seq_len)
|
||||
num_tokens = batch_size * seq_len
|
||||
if (num_tokens
|
||||
>= self.scheduler_config.max_num_batched_tokens):
|
||||
break
|
||||
seq_len = seq_len * 2
|
||||
end = time.time()
|
||||
logger.info(
|
||||
" -- Compilation for prefix prefill done in %.2f [secs].",
|
||||
end - start)
|
||||
|
||||
# Decode
|
||||
logger.info(
|
||||
"Compiling the model with different input shapes for decode:")
|
||||
start = time.time()
|
||||
seq_len = 1
|
||||
batch_size = 8 # Must be in sync with _get_padded_batch_size()
|
||||
while True:
|
||||
self.dummy_run(self.kv_caches, batch_size, seq_len,
|
||||
ExecutionMode.DECODE)
|
||||
self.dummy_run(self.kv_caches,
|
||||
batch_size,
|
||||
seq_len,
|
||||
exec_mode=ExecutionMode.DECODE)
|
||||
xm.wait_device_ops()
|
||||
logger.info(" -- batch_size: %d, seq_len: %d, max_num_seqs = %d",
|
||||
batch_size, seq_len,
|
||||
self.scheduler_config.max_num_seqs)
|
||||
logger.info(" batch_size: %d, seq_len: %d", batch_size, seq_len)
|
||||
|
||||
if batch_size >= self.scheduler_config.max_num_seqs:
|
||||
break
|
||||
|
||||
batch_size = batch_size + 16 if batch_size >= 16 else batch_size * 2
|
||||
|
||||
end = time.time()
|
||||
logger.info("Compilation for decode shapes is done in %.2f [secs].",
|
||||
logger.info(" -- Compilation for decode done in %.2f [secs].",
|
||||
end - start)
|
||||
|
||||
def _initialize_kv_cache(self):
|
||||
kv_cache_spec = self.get_kv_cache_spec()
|
||||
|
||||
kv_cache_config = get_kv_cache_config(vllm_config, kv_cache_spec,
|
||||
availble_gpu_memory)
|
||||
def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None:
|
||||
"""
|
||||
Initialize KV cache based on `kv_cache_config`.
|
||||
@ -810,7 +837,7 @@ class ModelWrapperV1(nn.Module):
|
||||
memory profiling at initialization.
|
||||
"""
|
||||
# Skip this in memory profiling at initialization.
|
||||
if attn_metadata is not None:
|
||||
if attn_metadata is not None and kv_caches[0][0].numel() > 0:
|
||||
# index_copy_(slot_mapping) only works when the inserted dimension
|
||||
# is 0. However, the KV cache in the Pallas backend has the shape
|
||||
# [num_kv_heads, num_blocks, block_size, head_size]. To make it
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
"""A TPU worker class."""
|
||||
import os
|
||||
from typing import Optional
|
||||
from typing import Optional, Dict
|
||||
|
||||
import torch
|
||||
import torch.distributed
|
||||
@ -13,10 +13,13 @@ from vllm.distributed import (ensure_model_parallel_initialized,
|
||||
init_distributed_environment)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor import set_random_seed
|
||||
from vllm.v1.kv_cache_interface import FullAttentionSpec
|
||||
from vllm.v1.attention.backends.pallas import PallasAttentionBackend
|
||||
from vllm.v1.core.scheduler import SchedulerOutput
|
||||
from vllm.v1.outputs import ModelRunnerOutput
|
||||
from vllm.v1.worker.tpu_model_runner import ExecutionMode, TPUModelRunner
|
||||
from vllm.v1.worker.worker_base import WorkerBase
|
||||
from vllm.v1.utils import bind_kv_cache
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
@ -74,20 +77,29 @@ class TPUWorker(WorkerBase):
|
||||
def determine_available_memory(self) -> int:
|
||||
assert self.model_runner is not None
|
||||
|
||||
num_layers = self.model_config.get_num_layers(self.parallel_config)
|
||||
kv_caches: Dict[str, torch.Tensor] = {}
|
||||
kv_cache_spec = self.model_runner.get_kv_cache_spec()
|
||||
for layer_name, layer_spec in kv_cache_spec.items():
|
||||
if isinstance(layer_spec, FullAttentionSpec):
|
||||
dtype = layer_spec.dtype
|
||||
|
||||
# use an empty tensor instead of `None`` to force Dynamo to pass
|
||||
# it by reference, rather by specializing on the value ``None``.
|
||||
# the `dtype` argument does not matter, and we use `float32` as
|
||||
# a placeholder (it has wide hardware support).
|
||||
kv_caches = [(torch.tensor([], dtype=torch.float32,
|
||||
device=self.device),
|
||||
torch.tensor([], dtype=torch.float32,
|
||||
device=self.device))
|
||||
for _ in range(num_layers)]
|
||||
# Use an empty tensor instead of `None`` to force Dynamo to pass
|
||||
# it by reference, rather by specializing on the value ``None``.
|
||||
tpu_k_cache = torch.tensor([], dtype=dtype, device=self.device)
|
||||
tpu_v_cache = torch.tensor([], dtype=dtype, device=self.device)
|
||||
|
||||
kv_caches[layer_name] = (tpu_k_cache, tpu_v_cache)
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
runner_kv_caches = []
|
||||
bind_kv_cache(
|
||||
kv_caches,
|
||||
self.vllm_config.compilation_config.static_forward_context,
|
||||
runner_kv_caches)
|
||||
|
||||
self.model_runner.dummy_run(
|
||||
kv_caches,
|
||||
runner_kv_caches,
|
||||
num_tokens=1,
|
||||
seq_len=self.scheduler_config.max_num_batched_tokens,
|
||||
exec_mode=ExecutionMode.PREFILL,
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user