[MISC] Move bind_kv_cache to worker module (#20900)

Signed-off-by: wangxiyuan <wangxiyuan1007@gmail.com>
This commit is contained in:
wangxiyuan 2025-07-14 17:40:00 +08:00 committed by GitHub
parent 697ef765ee
commit 1e9438e0b0
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 57 additions and 55 deletions

View File

@ -3,7 +3,7 @@
import torch
from vllm.v1.utils import bind_kv_cache
from vllm.v1.worker.utils import bind_kv_cache
def test_bind_kv_cache():

View File

@ -4,7 +4,6 @@ import argparse
import multiprocessing
import time
import weakref
from collections import defaultdict
from collections.abc import Sequence
from multiprocessing import connection
from multiprocessing.process import BaseProcess
@ -14,14 +13,12 @@ from typing import (TYPE_CHECKING, Any, Callable, Generic, Optional, TypeVar,
import torch
from vllm.logger import init_logger
from vllm.model_executor.models.utils import extract_layer_index
from vllm.usage.usage_lib import (UsageContext, is_usage_stats_enabled,
usage_message)
from vllm.utils import (get_open_port, get_open_zmq_ipc_path, get_tcp_uri,
kill_process_tree)
if TYPE_CHECKING:
from vllm.attention.layer import Attention
from vllm.v1.engine.coordinator import DPCoordinator
from vllm.v1.engine.utils import (CoreEngineActorManager,
CoreEngineProcManager)
@ -275,51 +272,6 @@ def shutdown(procs: list[BaseProcess]):
kill_process_tree(pid)
def bind_kv_cache(
kv_caches: dict[str, torch.Tensor],
forward_context: dict[str, "Attention"],
runner_kv_caches: list[torch.Tensor],
) -> None:
"""
Bind the allocated KV cache to both ModelRunner and forward context so
that the KV cache can be used in the forward pass.
This function:
1) Fills the ModelRunner's kv cache list (`runner_kv_caches`) with
kv_caches.
2) Associates each attention layer in the `forward_context` with its
corresponding KV cache in kv_caches.
Args:
kv_caches: The allocated kv_caches with layer names as keys.
forward_context: The global forward context containing all Attention
layers with layer names as keys.
runner_kv_caches: The kv_cache declared by ModelRunner.
"""
# Bind kv_caches to ModelRunner
assert len(runner_kv_caches) == 0
# Convert kv_caches dict to a list of tensors in the order of layer_index.
index2name = defaultdict(list)
for layer_name in kv_caches:
index2name[extract_layer_index(layer_name)].append(layer_name)
for layer_index in sorted(index2name.keys()):
layer_names = index2name[layer_index]
if len(layer_names) > 1:
# One typical case is encoder-decoder model, e.g., bart.
# The cross attention and self attention in the same decoder layer
# has different layer_name but the same layer_index.
raise NotImplementedError
layer_name = layer_names[0]
runner_kv_caches.append(kv_caches[layer_name])
# Bind kv_caches to forward context
for layer_name, kv_cache in kv_caches.items():
# NOTE: Use list because of v0 PP virtual engine.
forward_context[layer_name].kv_cache = [kv_cache]
def copy_slice(from_tensor: torch.Tensor, to_tensor: torch.Tensor,
length: int) -> torch.Tensor:
"""

View File

@ -62,13 +62,13 @@ from vllm.v1.spec_decode.eagle import EagleProposer
from vllm.v1.spec_decode.medusa import MedusaProposer
from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
from vllm.v1.spec_decode.ngram_proposer import NgramProposer
from vllm.v1.utils import bind_kv_cache
from vllm.v1.worker.block_table import BlockTable
from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin
from ..sample.logits_processor import LogitsProcessorManager
from .utils import (gather_mm_placeholders, initialize_kv_cache_for_kv_sharing,
from .utils import (bind_kv_cache, gather_mm_placeholders,
initialize_kv_cache_for_kv_sharing,
sanity_check_mm_encoder_outputs, scatter_mm_placeholders)
if TYPE_CHECKING:

View File

@ -42,11 +42,10 @@ from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, LogprobsLists,
LogprobsTensors, ModelRunnerOutput)
from vllm.v1.sample.tpu.metadata import TPUSupportedSamplingMetadata
from vllm.v1.sample.tpu.sampler import Sampler as TPUSampler
from vllm.v1.utils import bind_kv_cache
from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin
from vllm.v1.worker.tpu_input_batch import CachedRequestState, InputBatch
from .utils import (initialize_kv_cache_for_kv_sharing,
from .utils import (bind_kv_cache, initialize_kv_cache_for_kv_sharing,
sanity_check_mm_encoder_outputs)
if TYPE_CHECKING:

View File

@ -25,8 +25,9 @@ from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.kv_cache_interface import (AttentionSpec, KVCacheConfig,
KVCacheSpec)
from vllm.v1.outputs import ModelRunnerOutput
from vllm.v1.utils import bind_kv_cache, report_usage_stats
from vllm.v1.utils import report_usage_stats
from vllm.v1.worker.tpu_model_runner import TPUModelRunner
from vllm.v1.worker.utils import bind_kv_cache
logger = init_logger(__name__)

View File

@ -1,12 +1,17 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Optional
from collections import defaultdict
from typing import TYPE_CHECKING, Optional
import torch
from vllm.model_executor.models.interfaces import MultiModalEmbeddings
from vllm.model_executor.models.utils import extract_layer_index
from vllm.v1.kv_cache_interface import KVCacheGroupSpec
if TYPE_CHECKING:
from vllm.attention.layer import Attention
def sanity_check_mm_encoder_outputs(
mm_embeddings: MultiModalEmbeddings,
@ -110,3 +115,48 @@ def initialize_kv_cache_for_kv_sharing(
kv_caches[layer_name] = kv_caches[target_layer_name]
group_idx = layer_to_kv_cache_group_idx[target_layer_name]
kv_cache_groups[group_idx].layer_names.append(layer_name)
def bind_kv_cache(
kv_caches: dict[str, torch.Tensor],
forward_context: dict[str, "Attention"],
runner_kv_caches: list[torch.Tensor],
) -> None:
"""
Bind the allocated KV cache to both ModelRunner and forward context so
that the KV cache can be used in the forward pass.
This function:
1) Fills the ModelRunner's kv cache list (`runner_kv_caches`) with
kv_caches.
2) Associates each attention layer in the `forward_context` with its
corresponding KV cache in kv_caches.
Args:
kv_caches: The allocated kv_caches with layer names as keys.
forward_context: The global forward context containing all Attention
layers with layer names as keys.
runner_kv_caches: The kv_cache declared by ModelRunner.
"""
# Bind kv_caches to ModelRunner
assert len(runner_kv_caches) == 0
# Convert kv_caches dict to a list of tensors in the order of layer_index.
index2name = defaultdict(list)
for layer_name in kv_caches:
index2name[extract_layer_index(layer_name)].append(layer_name)
for layer_index in sorted(index2name.keys()):
layer_names = index2name[layer_index]
if len(layer_names) > 1:
# One typical case is encoder-decoder model, e.g., bart.
# The cross attention and self attention in the same decoder layer
# has different layer_name but the same layer_index.
raise NotImplementedError
layer_name = layer_names[0]
runner_kv_caches.append(kv_caches[layer_name])
# Bind kv_caches to forward context
for layer_name, kv_cache in kv_caches.items():
# NOTE: Use list because of v0 PP virtual engine.
forward_context[layer_name].kv_cache = [kv_cache]