Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
This commit is contained in:
Woosuk Kwon 2025-09-18 15:21:30 -07:00
parent 67d8c0c21b
commit a98eff0762
2 changed files with 13 additions and 12 deletions

View File

@@ -1,5 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Any
import torch
from vllm.attention.backends.abstract import AttentionBackend, AttentionType
@@ -8,6 +10,7 @@ from vllm.config import VllmConfig, get_layers_from_vllm_config
from vllm.v1.attention.backends.utils import AttentionMetadataBuilder
from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
KVCacheSpec, SlidingWindowSpec)
from vllm.v1.worker.utils import bind_kv_cache
def get_kv_cache_spec(
@@ -124,6 +127,8 @@ def _reshape_kv_cache(
def init_kv_cache(
runner_kv_caches: list[torch.Tensor],
forward_context: dict[str, Any],
kv_cache_config: KVCacheConfig,
attn_backends: dict[str, AttentionBackend],
device: torch.device,
@@ -131,4 +136,4 @@ def init_kv_cache(
kv_cache_raw_tensors = _allocate_kv_cache(kv_cache_config, device)
kv_caches = _reshape_kv_cache(kv_cache_config, kv_cache_raw_tensors,
attn_backends)
return kv_caches
bind_kv_cache(forward_context, kv_caches, runner_kv_caches)

View File

@@ -25,7 +25,6 @@ from vllm.v1.worker.gpu.input_batch import (InputBatch, InputBuffers,
prepare_inputs)
from vllm.v1.worker.gpu.sampler import Sampler
from vllm.v1.worker.gpu.states import RequestState
from vllm.v1.worker.utils import bind_kv_cache
logger = init_logger(__name__)
@@ -98,9 +97,6 @@ class GPUModelRunner:
m.consumed_memory / GiB_bytes,
time_after_load - time_before_load)
def profile_run(self):
pass
def get_kv_cache_spec(self):
return get_kv_cache_spec(self.vllm_config, self.kv_cache_dtype)
@@ -124,17 +120,14 @@ class GPUModelRunner:
self.device,
)
kv_caches = init_kv_cache(
self.kv_caches: list[torch.Tensor] = []
init_kv_cache(
self.kv_caches,
self.compilation_config.static_forward_context,
self.kv_cache_config,
self.attn_backends,
self.device,
)
self.kv_caches: list[torch.Tensor] = []
bind_kv_cache(
kv_caches,
self.compilation_config.static_forward_context,
self.kv_caches,
)
def _dummy_run(self, num_tokens: int, *args, **kwargs) -> None:
return None, None
@@ -143,6 +136,9 @@
**kwargs) -> None:
return None
def profile_run(self):
pass
def update_states(self, scheduler_output: SchedulerOutput) -> None:
# for req_id in scheduler_output.preempted_req_ids:
# self.req_states.remove_request(req_id)