From a98eff0762fd75989b680ea3a40d80084737d889 Mon Sep 17 00:00:00 2001
From: Woosuk Kwon
Date: Thu, 18 Sep 2025 15:21:30 -0700
Subject: [PATCH] minor

Signed-off-by: Woosuk Kwon
---
Review notes (text in this section is not part of the commit message):

* init_kv_cache() no longer returns the reshaped KV caches. It now takes
  runner_kv_caches and forward_context as its first two parameters and
  calls bind_kv_cache() itself; GPUModelRunner passes self.kv_caches and
  compilation_config.static_forward_context instead of binding afterwards.
* bind_kv_cache() is called as (kv_caches, forward_context,
  runner_kv_caches), the same positional order as the call this patch
  removes from model_runner.py. The previous revision of this patch passed
  forward_context and kv_caches in swapped positions.
* profile_run() is only moved below the dummy-run helpers; its body is
  unchanged.
* NOTE(review): the line that closes the init_kv_cache() signature is
  outside every hunk. If it declares a non-None return type, that
  annotation is now stale -- confirm against the target file.
* NOTE(review): the post-image blob id on the attn_utils.py "index" line
  predates the argument-order fix.

 vllm/v1/worker/gpu/attn_utils.py   |  7 ++++++-
 vllm/v1/worker/gpu/model_runner.py | 18 +++++++-----------
 2 files changed, 13 insertions(+), 12 deletions(-)

diff --git a/vllm/v1/worker/gpu/attn_utils.py b/vllm/v1/worker/gpu/attn_utils.py
index 87cc9b610b918..631bcd8023526 100644
--- a/vllm/v1/worker/gpu/attn_utils.py
+++ b/vllm/v1/worker/gpu/attn_utils.py
@@ -1,5 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+from typing import Any
+
 import torch
 
 from vllm.attention.backends.abstract import AttentionBackend, AttentionType
@@ -8,6 +10,7 @@ from vllm.config import VllmConfig, get_layers_from_vllm_config
 from vllm.v1.attention.backends.utils import AttentionMetadataBuilder
 from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
                                         KVCacheSpec, SlidingWindowSpec)
+from vllm.v1.worker.utils import bind_kv_cache
 
 
 def get_kv_cache_spec(
@@ -124,6 +127,8 @@ def _reshape_kv_cache(
 
 
 def init_kv_cache(
+    runner_kv_caches: list[torch.Tensor],
+    forward_context: dict[str, Any],
     kv_cache_config: KVCacheConfig,
     attn_backends: dict[str, AttentionBackend],
     device: torch.device,
@@ -131,4 +136,4 @@ def init_kv_cache(
     kv_cache_raw_tensors = _allocate_kv_cache(kv_cache_config, device)
     kv_caches = _reshape_kv_cache(kv_cache_config, kv_cache_raw_tensors,
                                   attn_backends)
-    return kv_caches
+    bind_kv_cache(kv_caches, forward_context, runner_kv_caches)
diff --git a/vllm/v1/worker/gpu/model_runner.py b/vllm/v1/worker/gpu/model_runner.py
index 3339544c4eff5..d591b3285297e 100644
--- a/vllm/v1/worker/gpu/model_runner.py
+++ b/vllm/v1/worker/gpu/model_runner.py
@@ -25,7 +25,6 @@ from vllm.v1.worker.gpu.input_batch import (InputBatch, InputBuffers,
                                             prepare_inputs)
 from vllm.v1.worker.gpu.sampler import Sampler
 from vllm.v1.worker.gpu.states import RequestState
-from vllm.v1.worker.utils import bind_kv_cache
 
 logger = init_logger(__name__)
 
@@ -98,9 +97,6 @@ class GPUModelRunner:
                     m.consumed_memory / GiB_bytes,
                     time_after_load - time_before_load)
 
-    def profile_run(self):
-        pass
-
     def get_kv_cache_spec(self):
         return get_kv_cache_spec(self.vllm_config, self.kv_cache_dtype)
 
@@ -124,17 +120,14 @@ class GPUModelRunner:
             self.device,
         )
 
-        kv_caches = init_kv_cache(
+        self.kv_caches: list[torch.Tensor] = []
+        init_kv_cache(
+            self.kv_caches,
+            self.compilation_config.static_forward_context,
             self.kv_cache_config,
             self.attn_backends,
             self.device,
         )
-        self.kv_caches: list[torch.Tensor] = []
-        bind_kv_cache(
-            kv_caches,
-            self.compilation_config.static_forward_context,
-            self.kv_caches,
-        )
 
     def _dummy_run(self, num_tokens: int, *args, **kwargs) -> None:
         return None, None
@@ -143,6 +136,9 @@ class GPUModelRunner:
                            **kwargs) -> None:
         return None
 
+    def profile_run(self):
+        pass
+
     def update_states(self, scheduler_output: SchedulerOutput) -> None:
         # for req_id in scheduler_output.preempted_req_ids:
         #     self.req_states.remove_request(req_id)