diff --git a/tests/v1/test_utils.py b/tests/v1/test_utils.py
index a3df882a9e29e..fd0e630ce178a 100644
--- a/tests/v1/test_utils.py
+++ b/tests/v1/test_utils.py
@@ -3,7 +3,7 @@
 
 import torch
 
-from vllm.v1.utils import bind_kv_cache
+from vllm.v1.worker.utils import bind_kv_cache
 
 
 def test_bind_kv_cache():
diff --git a/vllm/v1/utils.py b/vllm/v1/utils.py
index 6b40cf6fd36d5..97fec4704b480 100644
--- a/vllm/v1/utils.py
+++ b/vllm/v1/utils.py
@@ -4,7 +4,6 @@ import argparse
 import multiprocessing
 import time
 import weakref
-from collections import defaultdict
 from collections.abc import Sequence
 from multiprocessing import connection
 from multiprocessing.process import BaseProcess
@@ -14,14 +13,12 @@ from typing import (TYPE_CHECKING, Any, Callable, Generic, Optional, TypeVar,
 import torch
 
 from vllm.logger import init_logger
-from vllm.model_executor.models.utils import extract_layer_index
 from vllm.usage.usage_lib import (UsageContext, is_usage_stats_enabled,
                                   usage_message)
 from vllm.utils import (get_open_port, get_open_zmq_ipc_path, get_tcp_uri,
                         kill_process_tree)
 
 if TYPE_CHECKING:
-    from vllm.attention.layer import Attention
     from vllm.v1.engine.coordinator import DPCoordinator
     from vllm.v1.engine.utils import (CoreEngineActorManager,
                                       CoreEngineProcManager)
@@ -275,51 +272,6 @@ def shutdown(procs: list[BaseProcess]):
             kill_process_tree(pid)
 
 
-def bind_kv_cache(
-    kv_caches: dict[str, torch.Tensor],
-    forward_context: dict[str, "Attention"],
-    runner_kv_caches: list[torch.Tensor],
-) -> None:
-    """
-    Bind the allocated KV cache to both ModelRunner and forward context so
-    that the KV cache can be used in the forward pass.
-
-    This function:
-      1) Fills the ModelRunner's kv cache list (`runner_kv_caches`) with
-         kv_caches.
-      2) Associates each attention layer in the `forward_context` with its
-         corresponding KV cache in kv_caches.
-
-    Args:
-        kv_caches: The allocated kv_caches with layer names as keys.
-        forward_context: The global forward context containing all Attention
-            layers with layer names as keys.
-        runner_kv_caches: The kv_cache declared by ModelRunner.
-    """
-    # Bind kv_caches to ModelRunner
-    assert len(runner_kv_caches) == 0
-
-    # Convert kv_caches dict to a list of tensors in the order of layer_index.
-    index2name = defaultdict(list)
-    for layer_name in kv_caches:
-        index2name[extract_layer_index(layer_name)].append(layer_name)
-
-    for layer_index in sorted(index2name.keys()):
-        layer_names = index2name[layer_index]
-        if len(layer_names) > 1:
-            # One typical case is encoder-decoder model, e.g., bart.
-            # The cross attention and self attention in the same decoder layer
-            # has different layer_name but the same layer_index.
-            raise NotImplementedError
-        layer_name = layer_names[0]
-        runner_kv_caches.append(kv_caches[layer_name])
-
-    # Bind kv_caches to forward context
-    for layer_name, kv_cache in kv_caches.items():
-        # NOTE: Use list because of v0 PP virtual engine.
-        forward_context[layer_name].kv_cache = [kv_cache]
-
-
 def copy_slice(from_tensor: torch.Tensor, to_tensor: torch.Tensor,
                length: int) -> torch.Tensor:
     """
diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py
index 4551cb2df98ac..734df82589ac4 100644
--- a/vllm/v1/worker/gpu_model_runner.py
+++ b/vllm/v1/worker/gpu_model_runner.py
@@ -62,13 +62,13 @@ from vllm.v1.spec_decode.eagle import EagleProposer
 from vllm.v1.spec_decode.medusa import MedusaProposer
 from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
 from vllm.v1.spec_decode.ngram_proposer import NgramProposer
-from vllm.v1.utils import bind_kv_cache
 from vllm.v1.worker.block_table import BlockTable
 from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
 from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin
 
 from ..sample.logits_processor import LogitsProcessorManager
-from .utils import (gather_mm_placeholders, initialize_kv_cache_for_kv_sharing,
+from .utils import (bind_kv_cache, gather_mm_placeholders,
+                    initialize_kv_cache_for_kv_sharing,
                     sanity_check_mm_encoder_outputs, scatter_mm_placeholders)
 
 if TYPE_CHECKING:
diff --git a/vllm/v1/worker/tpu_model_runner.py b/vllm/v1/worker/tpu_model_runner.py
index eb96e56f495f1..82a203caf2b70 100644
--- a/vllm/v1/worker/tpu_model_runner.py
+++ b/vllm/v1/worker/tpu_model_runner.py
@@ -42,11 +42,10 @@ from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, LogprobsLists,
                              LogprobsTensors, ModelRunnerOutput)
 from vllm.v1.sample.tpu.metadata import TPUSupportedSamplingMetadata
 from vllm.v1.sample.tpu.sampler import Sampler as TPUSampler
-from vllm.v1.utils import bind_kv_cache
 from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin
 from vllm.v1.worker.tpu_input_batch import CachedRequestState, InputBatch
 
-from .utils import (initialize_kv_cache_for_kv_sharing,
+from .utils import (bind_kv_cache, initialize_kv_cache_for_kv_sharing,
                     sanity_check_mm_encoder_outputs)
 
 if TYPE_CHECKING:
diff --git a/vllm/v1/worker/tpu_worker.py b/vllm/v1/worker/tpu_worker.py
index c5336e9ad519e..c4bf40d665477 100644
--- a/vllm/v1/worker/tpu_worker.py
+++ b/vllm/v1/worker/tpu_worker.py
@@ -25,8 +25,9 @@ from vllm.v1.core.sched.output import SchedulerOutput
 from vllm.v1.kv_cache_interface import (AttentionSpec, KVCacheConfig,
                                         KVCacheSpec)
 from vllm.v1.outputs import ModelRunnerOutput
-from vllm.v1.utils import bind_kv_cache, report_usage_stats
+from vllm.v1.utils import report_usage_stats
 from vllm.v1.worker.tpu_model_runner import TPUModelRunner
+from vllm.v1.worker.utils import bind_kv_cache
 
 logger = init_logger(__name__)
 
diff --git a/vllm/v1/worker/utils.py b/vllm/v1/worker/utils.py
index 70339ff2f0051..3ecb1d7dd6560 100644
--- a/vllm/v1/worker/utils.py
+++ b/vllm/v1/worker/utils.py
@@ -1,12 +1,17 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-from typing import Optional
+from collections import defaultdict
+from typing import TYPE_CHECKING, Optional
 
 import torch
 
 from vllm.model_executor.models.interfaces import MultiModalEmbeddings
+from vllm.model_executor.models.utils import extract_layer_index
 from vllm.v1.kv_cache_interface import KVCacheGroupSpec
 
+if TYPE_CHECKING:
+    from vllm.attention.layer import Attention
+
 
 def sanity_check_mm_encoder_outputs(
     mm_embeddings: MultiModalEmbeddings,
@@ -110,3 +115,48 @@ def initialize_kv_cache_for_kv_sharing(
         kv_caches[layer_name] = kv_caches[target_layer_name]
         group_idx = layer_to_kv_cache_group_idx[target_layer_name]
         kv_cache_groups[group_idx].layer_names.append(layer_name)
+
+
+def bind_kv_cache(
+    kv_caches: dict[str, torch.Tensor],
+    forward_context: dict[str, "Attention"],
+    runner_kv_caches: list[torch.Tensor],
+) -> None:
+    """
+    Bind the allocated KV cache to both ModelRunner and forward context so
+    that the KV cache can be used in the forward pass.
+
+    This function:
+      1) Fills the ModelRunner's kv cache list (`runner_kv_caches`) with
+         kv_caches.
+      2) Associates each attention layer in the `forward_context` with its
+         corresponding KV cache in kv_caches.
+
+    Args:
+        kv_caches: The allocated kv_caches with layer names as keys.
+        forward_context: The global forward context containing all Attention
+            layers with layer names as keys.
+        runner_kv_caches: The kv_cache declared by ModelRunner.
+    """
+    # Bind kv_caches to ModelRunner
+    assert len(runner_kv_caches) == 0
+
+    # Convert kv_caches dict to a list of tensors in the order of layer_index.
+    index2name = defaultdict(list)
+    for layer_name in kv_caches:
+        index2name[extract_layer_index(layer_name)].append(layer_name)
+
+    for layer_index in sorted(index2name.keys()):
+        layer_names = index2name[layer_index]
+        if len(layer_names) > 1:
+            # One typical case is encoder-decoder model, e.g., bart.
+            # The cross attention and self attention in the same decoder layer
+            # has different layer_name but the same layer_index.
+            raise NotImplementedError
+        layer_name = layer_names[0]
+        runner_kv_caches.append(kv_caches[layer_name])
+
+    # Bind kv_caches to forward context
+    for layer_name, kv_cache in kv_caches.items():
+        # NOTE: Use list because of v0 PP virtual engine.
+        forward_context[layer_name].kv_cache = [kv_cache]
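
For reviewers, a minimal sketch of how the relocated helper can be exercised from its new home, `vllm.v1.worker.utils.bind_kv_cache`. The layer names, tensor shapes, and the `SimpleNamespace` stand-ins for real `Attention` layers below are illustrative assumptions, not code from this PR; only the import path and the function's behavior come from the diff.

```python
# Hedged usage sketch for bind_kv_cache (assumed layer names / shapes).
from types import SimpleNamespace

import torch

from vllm.v1.worker.utils import bind_kv_cache  # new location per this diff


def demo_bind_kv_cache() -> None:
    # Layer names contain exactly one integer, which is what
    # extract_layer_index() uses to order the caches; the prefix/suffix
    # are hypothetical.
    layer_names = [f"model.layers.{i}.self_attn.attn" for i in range(2)]

    # One KV cache tensor per attention layer (shape is arbitrary here).
    kv_caches = {name: torch.zeros(16, 16) for name in layer_names}

    # Stand-in forward context: any object with a `kv_cache` attribute works
    # for this demo in place of a real vllm Attention layer.
    forward_context = {
        name: SimpleNamespace(kv_cache=None) for name in layer_names
    }

    # Must start empty; bind_kv_cache asserts this before filling it.
    runner_kv_caches: list[torch.Tensor] = []

    bind_kv_cache(kv_caches, forward_context, runner_kv_caches)

    # runner_kv_caches is now ordered by layer index, and each layer's
    # forward context holds its cache wrapped in a single-element list.
    assert runner_kv_caches[0] is kv_caches[layer_names[0]]
    assert forward_context[layer_names[1]].kv_cache[0] is kv_caches[layer_names[1]]


if __name__ == "__main__":
    demo_bind_kv_cache()
```

Since this is a pure move (the function body is unchanged), existing callers only need their import updated, as done above for the GPU/TPU model runners and the TPU worker.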