Signed-off-by: Woosuk Kwon <woosuk@thinkingmachines.ai>
This commit is contained in:
Woosuk Kwon 2025-09-15 23:56:08 +00:00
parent 9f2becd3e6
commit dfc84b11a9
6 changed files with 178 additions and 79 deletions

View File

@ -30,6 +30,7 @@ class NewRequestData:
mm_features: list[MultiModalFeatureSpec]
sampling_params: Optional[SamplingParams]
pooling_params: Optional[PoolingParams]
block_ids: tuple[list[int], ...]
num_computed_tokens: int
lora_request: Optional[LoRARequest]
@ -45,6 +46,7 @@ class NewRequestData:
mm_features=request.mm_features,
sampling_params=request.sampling_params,
pooling_params=request.pooling_params,
block_ids=block_ids,
num_computed_tokens=request.num_computed_tokens,
lora_request=request.lora_request,
)
@ -55,6 +57,7 @@ class NewRequestData:
f"prompt_token_ids={self.prompt_token_ids},"
f"mm_features={self.mm_features},"
f"sampling_params={self.sampling_params},"
f"block_ids={self.block_ids},"
f"num_computed_tokens={self.num_computed_tokens},"
f"lora_request={self.lora_request}"
")")
@ -66,6 +69,7 @@ class NewRequestData:
f"prompt_token_ids_len={len(self.prompt_token_ids)},"
f"mm_features={self.mm_features},"
f"sampling_params={self.sampling_params},"
f"block_ids={self.block_ids},"
f"num_computed_tokens={self.num_computed_tokens},"
f"lora_request={self.lora_request}"
")")
@ -73,17 +77,52 @@ class NewRequestData:
@bc_linter_include
@dataclass
class CachedRequestData:
    """Per-step diff for requests already cached by the worker processes.

    Only the delta is sent each scheduling step to minimize communication
    cost; all list fields are parallel arrays indexed by position in
    ``req_ids``.
    """

    req_ids: list[str]
    # NOTE(review): appears superseded by `new_block_ids` below — kept for
    # backward compatibility; confirm against consumers before removing.
    cu_new_block_ids: tuple[np.ndarray, ...]
    # If resumed_from_preemption is False, new_block_ids will be appended to
    # the request's block IDs. If True, new_block_ids will be used as the
    # request's block IDs instead of appending to the existing block IDs.
    resumed_from_preemption: list[bool]
    # NOTE(woosuk): new_token_ids is only used for pipeline parallelism.
    # When PP is not used, new_token_ids will be empty.
    new_token_ids: list[list[int]]
    new_block_ids: list[Optional[tuple[list[int], ...]]]
    num_computed_tokens: list[int]

    @property
    def num_reqs(self) -> int:
        """Number of cached requests described by this payload."""
        return len(self.req_ids)

    @classmethod
    def make_empty(cls) -> "CachedRequestData":
        # Quoted forward reference: the class name is not yet bound when this
        # annotation is evaluated without postponed annotation evaluation.
        return cls(
            req_ids=[],
            # Fix: previously omitted, making this constructor call raise
            # TypeError for the missing required field.
            cu_new_block_ids=(),
            resumed_from_preemption=[],
            new_token_ids=[],
            new_block_ids=[],
            num_computed_tokens=[],
        )
@bc_linter_include
@dataclass
class SchedulerOutput:
# list of the requests that are scheduled for the first time.
# We cache the request's data in each worker process, so that we don't
# need to re-send it every scheduling step.
scheduled_new_reqs: list[NewRequestData]
# list of the requests that have been scheduled before.
# Since the request's data is already cached in the worker processes,
# we only send the diff to minimize the communication cost.
scheduled_cached_reqs: CachedRequestData
# req_id -> num_scheduled_tokens
# Number of tokens scheduled for each request.
num_scheduled_tokens: dict[str, int]
# Total number of tokens scheduled for all requests.
# Equal to sum(num_scheduled_tokens.values())
total_num_scheduled_tokens: int
# req_id -> spec_token_ids
# If a request does not have any spec decode tokens, it will not be
@ -97,11 +136,13 @@ class SchedulerOutput:
# This can be used for cascade attention.
num_common_prefix_blocks: list[int]
preempted_req_ids: set[str]
# Request IDs that are finished in between the previous and the current
# steps. This is used to notify the workers about the finished requests
# so that they can free the cached states for those requests.
finished_req_ids: set[str]
# list of mm_hash strings associated with the encoder outputs to be
# freed from the encoder cache.
free_encoder_mm_hashes: list[str]
# Dict of request ids to their index within the batch
# for filling the next token bitmask

View File

@ -2,8 +2,10 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch
from vllm.attention.backends.abstract import AttentionType
from vllm.attention.backends.abstract import AttentionBackend, AttentionType
from vllm.attention.layer import Attention
from vllm.v1.attention.backends.utils import AttentionMetadataBuilder
from vllm.v1.kv_cache_interface import KVCacheConfig
from vllm.config import VllmConfig, get_layers_from_vllm_config
from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheSpec,
SlidingWindowSpec)
@ -40,8 +42,30 @@ def get_kv_cache_spec(
return kv_cache_spec
def init_attn_backend(vllm_config: VllmConfig):
def init_attn_backend(
    kv_cache_config: KVCacheConfig,
    vllm_config: VllmConfig,
    device: torch.device,
) -> tuple[dict[str, AttentionBackend], dict[str, AttentionMetadataBuilder]]:
    """Resolve the attention backend and metadata builder for every layer.

    Returns two dicts keyed by layer name: the attention backend class for
    each layer, and the (group-shared) metadata builder for each layer.
    """
    attn_backends: dict[str, AttentionBackend] = {}
    attn_metadata_builders: dict[str, AttentionMetadataBuilder] = {}
    attn_layers = get_layers_from_vllm_config(vllm_config, Attention)
    for kv_cache_group_spec in kv_cache_config.kv_cache_groups:
        layer_names = kv_cache_group_spec.layer_names
        # Only one member layer is queried for the backend — this assumes
        # every layer in a KV-cache group resolves to the same backend
        # (TODO confirm; the code does not verify it).
        any_layer_name = next(iter(layer_names))
        attn_backend = attn_layers[any_layer_name].get_attn_backend()
        # One builder instance per group, shared by all layers in the group.
        attn_metadata_builder = attn_backend.get_builder_cls()(
            kv_cache_group_spec.kv_cache_spec,
            layer_names,
            vllm_config,
            device,
        )
        for layer_name in layer_names:
            attn_backends[layer_name] = attn_backend
            attn_metadata_builders[layer_name] = attn_metadata_builder
    return attn_backends, attn_metadata_builders
def _allocate_kv_cache(
@ -68,13 +92,42 @@ def _allocate_kv_cache(
def _reshape_kv_cache(
kv_cache_config: KVCacheConfig,
kv_cache_raw_tensors: dict[str, torch.Tensor],
):
pass
attn_backends: dict[str, AttentionBackend],
) -> dict[str, torch.Tensor]:
kv_caches: dict[str, torch.Tensor] = {}
for kv_cache_group_spec in kv_cache_config.kv_cache_groups:
kv_cache_spec = kv_cache_group_spec.kv_cache_spec
for layer_name in kv_cache_group_spec.layer_names:
raw_tensor = kv_cache_raw_tensors[layer_name]
assert raw_tensor.numel() % kv_cache_spec.page_size_bytes == 0
num_blocks = (raw_tensor.numel() // kv_cache_spec.page_size_bytes)
attn_backend = attn_backends[layer_name]
kv_cache_shape = attn_backend.get_kv_cache_shape(
num_blocks, kv_cache_spec.block_size,
kv_cache_spec.num_kv_heads, kv_cache_spec.head_size)
dtype = kv_cache_spec.dtype
kv_cache_stride_order = attn_backend.get_kv_cache_stride_order()
kv_cache_shape = tuple(kv_cache_shape[i]
for i in kv_cache_stride_order)
inv_order = [
kv_cache_stride_order.index(i)
for i in range(len(kv_cache_stride_order))
]
raw_tensor = raw_tensor.view(dtype)
raw_tensor = raw_tensor.view(kv_cache_shape)
kv_caches[layer_name] = raw_tensor.permute(*inv_order)
return kv_caches
def init_kv_cache(
    kv_cache_config: KVCacheConfig,
    attn_backends: dict[str, AttentionBackend],
    device: torch.device,
) -> dict[str, torch.Tensor]:
    """Allocate the per-layer KV-cache buffers on `device` and reshape them
    into each attention backend's expected layout.

    Returns:
        Per-layer KV-cache tensors keyed by layer name.
    """
    kv_cache_raw_tensors = _allocate_kv_cache(kv_cache_config, device)
    # Fix: the raw tensors were previously reshaped twice — a stale
    # two-argument call (missing `attn_backends`) whose result was
    # immediately discarded. A single three-argument call suffices.
    kv_caches = _reshape_kv_cache(kv_cache_config, kv_cache_raw_tensors,
                                  attn_backends)
    return kv_caches

View File

@ -1,26 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import time
from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.model_executor.model_loader import get_model_loader
from vllm.utils import DeviceMemoryProfiler, GiB_bytes
logger = init_logger(__name__)
def load_model(vllm_config: VllmConfig):
    """Load the model described by `vllm_config`, logging the device memory
    consumed and the wall-clock time taken."""
    start_ts = time.perf_counter()
    with DeviceMemoryProfiler() as mem_profiler:
        loader = get_model_loader(vllm_config.load_config)
        logger.info("Loading model from scratch...")
        model = loader.load_model(vllm_config=vllm_config,
                                  model_config=vllm_config.model_config)
    elapsed = time.perf_counter() - start_ts
    logger.info("Model loading took %.4f GiB and %.6f seconds",
                mem_profiler.consumed_memory / GiB_bytes,
                elapsed)
    return model

View File

@ -1,5 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import time
from copy import deepcopy
from typing import Any
@ -14,12 +15,14 @@ from vllm.v1.attention.backends.utils import CommonAttentionMetadata
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.kv_cache_interface import KVCacheConfig
from vllm.v1.sample.sampler import SamplerOutput
from vllm.v1.worker.gpu.attn_utils import get_kv_cache_spec, init_attn_backend
from vllm.v1.worker.gpu.attn_utils import get_kv_cache_spec, init_attn_backend, init_kv_cache
from vllm.v1.worker.utils import bind_kv_cache
from vllm.v1.worker.gpu.block_table import BlockTables
from vllm.v1.worker.gpu.init_utils import load_model
from vllm.v1.worker.gpu.input_batch import (InputBatch, InputBuffers,
prepare_inputs)
from vllm.v1.worker.gpu.sampler import Sampler
from vllm.model_executor.model_loader import get_model_loader
from vllm.utils import DeviceMemoryProfiler, GiB_bytes
from vllm.v1.worker.gpu.states import RequestState
logger = init_logger(__name__)
@ -52,6 +55,7 @@ class GPUModelRunner:
# Quantized KV cache.
self.kv_cache_dtype = STR_DTYPE_TO_TORCH_DTYPE[
self.cache_config.cache_dtype]
self.is_pooling_model = False
self.vocab_size = self.model_config.get_vocab_size()
self.max_model_len = self.model_config.max_model_len
@ -74,8 +78,27 @@ class GPUModelRunner:
)
self.sampler = Sampler()
def load_model(self) -> None:
self.model = load_model(self.vllm_config)
def load_model(self, eep_scale_up: bool = False) -> None:
    """Load the model onto the device, recording memory usage and load time.

    Args:
        eep_scale_up: accepted for signature compatibility with other model
            runner implementations; not used here — TODO confirm.
    """
    time_before_load = time.perf_counter()
    with DeviceMemoryProfiler() as m:
        model_loader = get_model_loader(self.vllm_config.load_config)
        logger.info("Loading model from scratch...")
        self.model = model_loader.load_model(
            vllm_config=self.vllm_config,
            model_config=self.vllm_config.model_config,
        )
    time_after_load = time.perf_counter()
    # Retained so later memory-accounting queries can report model footprint.
    self.model_memory_usage = m.consumed_memory
    logger.info("Model loading took %.4f GiB and %.6f seconds",
                m.consumed_memory / GiB_bytes,
                time_after_load - time_before_load)
def profile_run(self):
    # No-op placeholder: kept for interface parity with other model runners
    # that perform a profiling dry run here.
    pass
def maybe_remove_all_loras(self, lora_config):
    # No-op placeholder: LoRA is not handled by this runner; kept for
    # interface parity with runners that support it.
    pass
def get_kv_cache_spec(self):
    # Delegate to the module-level helper, passing this runner's config and
    # the (possibly quantized) KV-cache dtype.
    return get_kv_cache_spec(self.vllm_config, self.kv_cache_dtype)
@ -93,7 +116,20 @@ class GPUModelRunner:
device=self.device,
pin_memory=self.pin_memory,
)
self.attn_metadata_builders = init_attn_backend(self.vllm_config)
self.attn_backends, self.attn_metadata_builders = init_attn_backend(
self.kv_cache_config,
self.vllm_config,
self.device,
)
kv_caches = init_kv_cache(self.kv_cache_config, self.attn_backends, self.device)
self.kv_caches: list[torch.Tensor] = []
bind_kv_cache(
kv_caches,
self.compilation_config.static_forward_context,
self.kv_caches,
)
def update_states(self, scheduler_output: SchedulerOutput) -> None:
for req_id in scheduler_output.preempted_req_ids:
@ -291,9 +327,11 @@ class GPUModelRunner:
return None
num_prompt_tokens_scheduled = ...
if not np.any(num_prompt_tokens_scheduled > 0 & needs_prompt_logprobs):
if not np.any((num_prompt_tokens_scheduled > 0) & needs_prompt_logprobs):
# The request already computed prompt logprobs.
return None
# TODO
return
def postprocess(

View File

@ -1559,12 +1559,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
for layer_name in layer_names:
attn_backend = layers[layer_name].get_attn_backend()
if layer_name in self.kv_sharing_fast_prefill_eligible_layers:
attn_backend = create_fast_prefill_custom_backend(
"FastPrefill",
attn_backend,
)
key = attn_backend.full_cls_name()
attn_backends[key] = attn_backend
attn_backend_layers[key].append(layer_name)
@ -1726,7 +1720,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
corresponding memory buffer for KV cache.
"""
kv_caches: dict[str, torch.Tensor] = {}
has_attn, has_mamba = False, False
for kv_cache_spec, group in self._kv_cache_spec_attn_group_iterator():
attn_backend = group.backend
for layer_name in group.layer_names:
@ -1736,35 +1729,34 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
assert raw_tensor.numel() % kv_cache_spec.page_size_bytes == 0
num_blocks = (raw_tensor.numel() //
kv_cache_spec.page_size_bytes)
if isinstance(kv_cache_spec, AttentionSpec):
has_attn = True
kv_cache_shape = attn_backend.get_kv_cache_shape(
num_blocks, kv_cache_spec.block_size,
kv_cache_spec.num_kv_heads, kv_cache_spec.head_size)
dtype = kv_cache_spec.dtype
try:
kv_cache_stride_order = \
attn_backend.get_kv_cache_stride_order()
assert len(kv_cache_stride_order) == len(
kv_cache_shape)
except (AttributeError, NotImplementedError):
kv_cache_stride_order = tuple(
range(len(kv_cache_shape)))
# The allocation respects the backend-defined stride order
# to ensure the semantic remains consistent for each
# backend. We first obtain the generic kv cache shape and
# then permute it according to the stride order which could
# result in a non-contiguous tensor.
kv_cache_shape = tuple(kv_cache_shape[i]
for i in kv_cache_stride_order)
# Maintain original KV shape view.
inv_order = [
kv_cache_stride_order.index(i)
for i in range(len(kv_cache_stride_order))
]
kv_caches[layer_name] = kv_cache_raw_tensors[
layer_name].view(dtype).view(kv_cache_shape).permute(
*inv_order)
kv_cache_shape = attn_backend.get_kv_cache_shape(
num_blocks, kv_cache_spec.block_size,
kv_cache_spec.num_kv_heads, kv_cache_spec.head_size)
dtype = kv_cache_spec.dtype
try:
kv_cache_stride_order = \
attn_backend.get_kv_cache_stride_order()
assert len(kv_cache_stride_order) == len(
kv_cache_shape)
except (AttributeError, NotImplementedError):
kv_cache_stride_order = tuple(
range(len(kv_cache_shape)))
# The allocation respects the backend-defined stride order
# to ensure the semantic remains consistent for each
# backend. We first obtain the generic kv cache shape and
# then permute it according to the stride order which could
# result in a non-contiguous tensor.
kv_cache_shape = tuple(kv_cache_shape[i]
for i in kv_cache_stride_order)
# Maintain original KV shape view.
inv_order = [
kv_cache_stride_order.index(i)
for i in range(len(kv_cache_stride_order))
]
kv_caches[layer_name] = kv_cache_raw_tensors[
layer_name].view(dtype).view(kv_cache_shape).permute(
*inv_order)
return kv_caches

View File

@ -31,7 +31,8 @@ from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec
from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, AsyncModelRunnerOutput,
DraftTokenIds, ModelRunnerOutput)
from vllm.v1.utils import report_usage_stats
from vllm.v1.worker.gpu_model_runner import GPUModelRunner
# from vllm.v1.worker.gpu_model_runner import GPUModelRunner
from vllm.v1.worker.gpu.model_runner import GPUModelRunner
from vllm.v1.worker.utils import is_residual_scattered_for_sp
from vllm.v1.worker.worker_base import WorkerBase
@ -682,8 +683,8 @@ class Worker(WorkerBase):
self.model_runner.save_tensorized_model(
tensorizer_config=tensorizer_config, )
def shutdown(self) -> None:
self.model_runner.ensure_kv_transfer_shutdown()
# def shutdown(self) -> None:
# self.model_runner.ensure_kv_transfer_shutdown()
def init_worker_distributed_environment(