mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-05-24 21:24:34 +08:00
fix capture model
This commit is contained in:
parent
996b92ccb4
commit
5fb9dbe6f6
@ -13,13 +13,11 @@ ray[default]
|
|||||||
# Install torch_xla
|
# Install torch_xla
|
||||||
--pre
|
--pre
|
||||||
--extra-index-url https://download.pytorch.org/whl/nightly/cpu
|
--extra-index-url https://download.pytorch.org/whl/nightly/cpu
|
||||||
|
--find-links https://storage.googleapis.com/libtpu-wheels/index.html
|
||||||
--find-links https://storage.googleapis.com/libtpu-releases/index.html
|
--find-links https://storage.googleapis.com/libtpu-releases/index.html
|
||||||
--find-links https://storage.googleapis.com/jax-releases/jax_nightly_releases.html
|
--find-links https://storage.googleapis.com/jax-releases/jax_nightly_releases.html
|
||||||
--find-links https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html
|
--find-links https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html
|
||||||
torch==2.6.0.dev20241126+cpu
|
torch==2.6.0.dev20241216+cpu
|
||||||
torchvision==0.20.0.dev20241126+cpu
|
torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.7.0.dev20250124-cp39-cp39-linux_x86_64.whl ; python_version == "3.9"
|
||||||
torch_xla[tpu] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.6.0.dev20241126-cp39-cp39-linux_x86_64.whl ; python_version == "3.9"
|
torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.7.0.dev20250124-cp310-cp310-linux_x86_64.whl ; python_version == "3.10"
|
||||||
torch_xla[tpu] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.6.0.dev20241126-cp310-cp310-linux_x86_64.whl ; python_version == "3.10"
|
torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.7.0.dev20250124-cp311-cp311-linux_x86_64.whl ; python_version == "3.11"
|
||||||
torch_xla[tpu] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.6.0.dev20241126-cp311-cp311-linux_x86_64.whl ; python_version == "3.11"
|
|
||||||
jaxlib==0.4.36.dev20241122
|
|
||||||
jax==0.4.36.dev20241122
|
|
||||||
@ -24,6 +24,7 @@ from vllm.v1.utils import bind_kv_cache
|
|||||||
from vllm.v1.worker.gpu_input_batch import (CachedRequestState, InputBatch,
|
from vllm.v1.worker.gpu_input_batch import (CachedRequestState, InputBatch,
|
||||||
ensure_decodes_first)
|
ensure_decodes_first)
|
||||||
from vllm.v1.worker.model_runner_base import ExecutionMode, ModelRunnerBase
|
from vllm.v1.worker.model_runner_base import ExecutionMode, ModelRunnerBase
|
||||||
|
from vllm.v1.core.kv_cache_utils import get_kv_cache_config
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from vllm.v1.core.scheduler import SchedulerOutput
|
from vllm.v1.core.scheduler import SchedulerOutput
|
||||||
@ -688,62 +689,88 @@ class TPUModelRunner(ModelRunnerBase):
|
|||||||
torch._dynamo.mark_dynamic(attn_metadata.context_lens, 0)
|
torch._dynamo.mark_dynamic(attn_metadata.context_lens, 0)
|
||||||
torch._dynamo.mark_dynamic(attn_metadata.block_tables, 0)
|
torch._dynamo.mark_dynamic(attn_metadata.block_tables, 0)
|
||||||
|
|
||||||
# TODO: Remove the attn_metadata above
|
with set_forward_context(attn_metadata, self.vllm_config, 0):
|
||||||
with set_forward_context(None, self.vllm_config):
|
|
||||||
assert self.model is not None
|
assert self.model is not None
|
||||||
self.model(token_ids, position_ids, None, kv_caches)
|
self.model(token_ids, position_ids, attn_metadata, kv_caches)
|
||||||
|
|
||||||
def capture_model(self) -> None:
|
def capture_model(self) -> None:
|
||||||
"""Compile the model."""
|
"""Compile the model."""
|
||||||
|
|
||||||
logger.info("Compiling the model with different input shapes.")
|
# Prefill
|
||||||
|
logger.info(
|
||||||
# Capture prefill shapes
|
"Compiling the model with different input shapes for prefill:")
|
||||||
start = time.perf_counter()
|
start = time.time()
|
||||||
for batch_size in [1]:
|
for batch_size in [1]:
|
||||||
seq_len = 16
|
seq_len = 16
|
||||||
while True:
|
while seq_len <= self.model_config.max_model_len:
|
||||||
self.dummy_run(self.kv_caches, batch_size, seq_len,
|
self.dummy_run(self.kv_caches,
|
||||||
ExecutionMode.PREFILL)
|
batch_size,
|
||||||
|
seq_len,
|
||||||
|
exec_mode=ExecutionMode.PREFILL)
|
||||||
xm.wait_device_ops()
|
xm.wait_device_ops()
|
||||||
logger.info(" -- batch_size: %d, seq_len: %d", batch_size,
|
logger.info(" batch_size: %d, seq_len: %d", batch_size,
|
||||||
seq_len)
|
seq_len)
|
||||||
|
|
||||||
if seq_len >= self.model_config.max_model_len:
|
|
||||||
break
|
|
||||||
|
|
||||||
num_tokens = batch_size * seq_len
|
num_tokens = batch_size * seq_len
|
||||||
if num_tokens >= self.scheduler_config.max_num_batched_tokens:
|
if num_tokens >= self.scheduler_config.max_num_batched_tokens:
|
||||||
break
|
break
|
||||||
|
|
||||||
# Move to next seq_len
|
|
||||||
seq_len = seq_len * 2
|
seq_len = seq_len * 2
|
||||||
|
|
||||||
end = time.perf_counter()
|
end = time.time()
|
||||||
logger.info("Compilation for prefill shapes is done in %.2f [secs].",
|
logger.info(" -- Compilation for prefill done in %.2f [secs].",
|
||||||
end - start)
|
end - start)
|
||||||
|
|
||||||
# Capture decode shapes.
|
# Prefix prefill
|
||||||
|
if self.scheduler_config.enable_chunked_prefill:
|
||||||
|
logger.info("Compiling the model with different input shapes for "
|
||||||
|
"prefix prefill:")
|
||||||
|
start = time.time()
|
||||||
|
for batch_size in [1]:
|
||||||
|
seq_len = 16
|
||||||
|
while seq_len <= self.model_config.max_model_len:
|
||||||
|
self.dummy_run(self.kv_caches,
|
||||||
|
batch_size,
|
||||||
|
seq_len,
|
||||||
|
exec_mode=ExecutionMode.PREFIX_PREFILL)
|
||||||
|
xm.wait_device_ops()
|
||||||
|
logger.info(" batch_size: %d, seq_len: %d", batch_size,
|
||||||
|
seq_len)
|
||||||
|
num_tokens = batch_size * seq_len
|
||||||
|
if (num_tokens
|
||||||
|
>= self.scheduler_config.max_num_batched_tokens):
|
||||||
|
break
|
||||||
|
seq_len = seq_len * 2
|
||||||
|
end = time.time()
|
||||||
|
logger.info(
|
||||||
|
" -- Compilation for prefix prefill done in %.2f [secs].",
|
||||||
|
end - start)
|
||||||
|
|
||||||
|
# Decode
|
||||||
|
logger.info(
|
||||||
|
"Compiling the model with different input shapes for decode:")
|
||||||
start = time.time()
|
start = time.time()
|
||||||
seq_len = 1
|
seq_len = 1
|
||||||
batch_size = 8 # Must be in sync with _get_padded_batch_size()
|
batch_size = 8 # Must be in sync with _get_padded_batch_size()
|
||||||
while True:
|
while True:
|
||||||
self.dummy_run(self.kv_caches, batch_size, seq_len,
|
self.dummy_run(self.kv_caches,
|
||||||
ExecutionMode.DECODE)
|
batch_size,
|
||||||
|
seq_len,
|
||||||
|
exec_mode=ExecutionMode.DECODE)
|
||||||
xm.wait_device_ops()
|
xm.wait_device_ops()
|
||||||
logger.info(" -- batch_size: %d, seq_len: %d, max_num_seqs = %d",
|
logger.info(" batch_size: %d, seq_len: %d", batch_size, seq_len)
|
||||||
batch_size, seq_len,
|
|
||||||
self.scheduler_config.max_num_seqs)
|
|
||||||
|
|
||||||
if batch_size >= self.scheduler_config.max_num_seqs:
|
if batch_size >= self.scheduler_config.max_num_seqs:
|
||||||
break
|
break
|
||||||
|
|
||||||
batch_size = batch_size + 16 if batch_size >= 16 else batch_size * 2
|
batch_size = batch_size + 16 if batch_size >= 16 else batch_size * 2
|
||||||
|
|
||||||
end = time.time()
|
end = time.time()
|
||||||
logger.info("Compilation for decode shapes is done in %.2f [secs].",
|
logger.info(" -- Compilation for decode done in %.2f [secs].",
|
||||||
end - start)
|
end - start)
|
||||||
|
|
||||||
|
def _initialize_kv_cache(self):
|
||||||
|
kv_cache_spec = self.get_kv_cache_spec()
|
||||||
|
|
||||||
|
kv_cache_config = get_kv_cache_config(vllm_config, kv_cache_spec,
|
||||||
|
availble_gpu_memory)
|
||||||
def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None:
|
def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None:
|
||||||
"""
|
"""
|
||||||
Initialize KV cache based on `kv_cache_config`.
|
Initialize KV cache based on `kv_cache_config`.
|
||||||
@ -810,7 +837,7 @@ class ModelWrapperV1(nn.Module):
|
|||||||
memory profiling at initialization.
|
memory profiling at initialization.
|
||||||
"""
|
"""
|
||||||
# Skip this in memory profiling at initialization.
|
# Skip this in memory profiling at initialization.
|
||||||
if attn_metadata is not None:
|
if attn_metadata is not None and kv_caches[0][0].numel() > 0:
|
||||||
# index_copy_(slot_mapping) only works when the inserted dimension
|
# index_copy_(slot_mapping) only works when the inserted dimension
|
||||||
# is 0. However, the KV cache in the Pallas backend has the shape
|
# is 0. However, the KV cache in the Pallas backend has the shape
|
||||||
# [num_kv_heads, num_blocks, block_size, head_size]. To make it
|
# [num_kv_heads, num_blocks, block_size, head_size]. To make it
|
||||||
|
|||||||
@ -1,6 +1,6 @@
|
|||||||
"""A TPU worker class."""
|
"""A TPU worker class."""
|
||||||
import os
|
import os
|
||||||
from typing import Optional
|
from typing import Optional, Dict
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.distributed
|
import torch.distributed
|
||||||
@ -13,10 +13,13 @@ from vllm.distributed import (ensure_model_parallel_initialized,
|
|||||||
init_distributed_environment)
|
init_distributed_environment)
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
from vllm.model_executor import set_random_seed
|
from vllm.model_executor import set_random_seed
|
||||||
|
from vllm.v1.kv_cache_interface import FullAttentionSpec
|
||||||
|
from vllm.v1.attention.backends.pallas import PallasAttentionBackend
|
||||||
from vllm.v1.core.scheduler import SchedulerOutput
|
from vllm.v1.core.scheduler import SchedulerOutput
|
||||||
from vllm.v1.outputs import ModelRunnerOutput
|
from vllm.v1.outputs import ModelRunnerOutput
|
||||||
from vllm.v1.worker.tpu_model_runner import ExecutionMode, TPUModelRunner
|
from vllm.v1.worker.tpu_model_runner import ExecutionMode, TPUModelRunner
|
||||||
from vllm.v1.worker.worker_base import WorkerBase
|
from vllm.v1.worker.worker_base import WorkerBase
|
||||||
|
from vllm.v1.utils import bind_kv_cache
|
||||||
|
|
||||||
logger = init_logger(__name__)
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
@ -74,20 +77,29 @@ class TPUWorker(WorkerBase):
|
|||||||
def determine_available_memory(self) -> int:
|
def determine_available_memory(self) -> int:
|
||||||
assert self.model_runner is not None
|
assert self.model_runner is not None
|
||||||
|
|
||||||
num_layers = self.model_config.get_num_layers(self.parallel_config)
|
kv_caches: Dict[str, torch.Tensor] = {}
|
||||||
|
kv_cache_spec = self.model_runner.get_kv_cache_spec()
|
||||||
|
for layer_name, layer_spec in kv_cache_spec.items():
|
||||||
|
if isinstance(layer_spec, FullAttentionSpec):
|
||||||
|
dtype = layer_spec.dtype
|
||||||
|
|
||||||
# use an empty tensor instead of ``None`` to force Dynamo to pass
|
# Use an empty tensor instead of ``None`` to force Dynamo to pass
|
||||||
# it by reference, rather than by specializing on the value ``None``.
|
# it by reference, rather than by specializing on the value ``None``.
|
||||||
# the `dtype` argument does not matter, and we use `float32` as
|
tpu_k_cache = torch.tensor([], dtype=dtype, device=self.device)
|
||||||
# a placeholder (it has wide hardware support).
|
tpu_v_cache = torch.tensor([], dtype=dtype, device=self.device)
|
||||||
kv_caches = [(torch.tensor([], dtype=torch.float32,
|
|
||||||
device=self.device),
|
kv_caches[layer_name] = (tpu_k_cache, tpu_v_cache)
|
||||||
torch.tensor([], dtype=torch.float32,
|
else:
|
||||||
device=self.device))
|
raise NotImplementedError
|
||||||
for _ in range(num_layers)]
|
|
||||||
|
runner_kv_caches = []
|
||||||
|
bind_kv_cache(
|
||||||
|
kv_caches,
|
||||||
|
self.vllm_config.compilation_config.static_forward_context,
|
||||||
|
runner_kv_caches)
|
||||||
|
|
||||||
self.model_runner.dummy_run(
|
self.model_runner.dummy_run(
|
||||||
kv_caches,
|
runner_kv_caches,
|
||||||
num_tokens=1,
|
num_tokens=1,
|
||||||
seq_len=self.scheduler_config.max_num_batched_tokens,
|
seq_len=self.scheduler_config.max_num_batched_tokens,
|
||||||
exec_mode=ExecutionMode.PREFILL,
|
exec_mode=ExecutionMode.PREFILL,
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user