[MISC] Move bind_kv_cache to worker module (#20900)

Signed-off-by: wangxiyuan <wangxiyuan1007@gmail.com>
This commit is contained in:
wangxiyuan 2025-07-14 17:40:00 +08:00 committed by GitHub
parent 697ef765ee
commit 1e9438e0b0
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 57 additions and 55 deletions

View File

@ -3,7 +3,7 @@
import torch import torch
from vllm.v1.utils import bind_kv_cache from vllm.v1.worker.utils import bind_kv_cache
def test_bind_kv_cache(): def test_bind_kv_cache():

View File

@ -4,7 +4,6 @@ import argparse
import multiprocessing import multiprocessing
import time import time
import weakref import weakref
from collections import defaultdict
from collections.abc import Sequence from collections.abc import Sequence
from multiprocessing import connection from multiprocessing import connection
from multiprocessing.process import BaseProcess from multiprocessing.process import BaseProcess
@ -14,14 +13,12 @@ from typing import (TYPE_CHECKING, Any, Callable, Generic, Optional, TypeVar,
import torch import torch
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.models.utils import extract_layer_index
from vllm.usage.usage_lib import (UsageContext, is_usage_stats_enabled, from vllm.usage.usage_lib import (UsageContext, is_usage_stats_enabled,
usage_message) usage_message)
from vllm.utils import (get_open_port, get_open_zmq_ipc_path, get_tcp_uri, from vllm.utils import (get_open_port, get_open_zmq_ipc_path, get_tcp_uri,
kill_process_tree) kill_process_tree)
if TYPE_CHECKING: if TYPE_CHECKING:
from vllm.attention.layer import Attention
from vllm.v1.engine.coordinator import DPCoordinator from vllm.v1.engine.coordinator import DPCoordinator
from vllm.v1.engine.utils import (CoreEngineActorManager, from vllm.v1.engine.utils import (CoreEngineActorManager,
CoreEngineProcManager) CoreEngineProcManager)
@ -275,51 +272,6 @@ def shutdown(procs: list[BaseProcess]):
kill_process_tree(pid) kill_process_tree(pid)
def bind_kv_cache(
kv_caches: dict[str, torch.Tensor],
forward_context: dict[str, "Attention"],
runner_kv_caches: list[torch.Tensor],
) -> None:
"""
Bind the allocated KV cache to both ModelRunner and forward context so
that the KV cache can be used in the forward pass.
This function:
1) Fills the ModelRunner's kv cache list (`runner_kv_caches`) with
kv_caches.
2) Associates each attention layer in the `forward_context` with its
corresponding KV cache in kv_caches.
Args:
kv_caches: The allocated kv_caches with layer names as keys.
forward_context: The global forward context containing all Attention
layers with layer names as keys.
runner_kv_caches: The kv_cache declared by ModelRunner.
"""
# Bind kv_caches to ModelRunner
assert len(runner_kv_caches) == 0
# Convert kv_caches dict to a list of tensors in the order of layer_index.
index2name = defaultdict(list)
for layer_name in kv_caches:
index2name[extract_layer_index(layer_name)].append(layer_name)
for layer_index in sorted(index2name.keys()):
layer_names = index2name[layer_index]
if len(layer_names) > 1:
# One typical case is encoder-decoder model, e.g., bart.
# The cross attention and self attention in the same decoder layer
# has different layer_name but the same layer_index.
raise NotImplementedError
layer_name = layer_names[0]
runner_kv_caches.append(kv_caches[layer_name])
# Bind kv_caches to forward context
for layer_name, kv_cache in kv_caches.items():
# NOTE: Use list because of v0 PP virtual engine.
forward_context[layer_name].kv_cache = [kv_cache]
def copy_slice(from_tensor: torch.Tensor, to_tensor: torch.Tensor, def copy_slice(from_tensor: torch.Tensor, to_tensor: torch.Tensor,
length: int) -> torch.Tensor: length: int) -> torch.Tensor:
""" """

View File

@ -62,13 +62,13 @@ from vllm.v1.spec_decode.eagle import EagleProposer
from vllm.v1.spec_decode.medusa import MedusaProposer from vllm.v1.spec_decode.medusa import MedusaProposer
from vllm.v1.spec_decode.metadata import SpecDecodeMetadata from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
from vllm.v1.spec_decode.ngram_proposer import NgramProposer from vllm.v1.spec_decode.ngram_proposer import NgramProposer
from vllm.v1.utils import bind_kv_cache
from vllm.v1.worker.block_table import BlockTable from vllm.v1.worker.block_table import BlockTable
from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin
from ..sample.logits_processor import LogitsProcessorManager from ..sample.logits_processor import LogitsProcessorManager
from .utils import (gather_mm_placeholders, initialize_kv_cache_for_kv_sharing, from .utils import (bind_kv_cache, gather_mm_placeholders,
initialize_kv_cache_for_kv_sharing,
sanity_check_mm_encoder_outputs, scatter_mm_placeholders) sanity_check_mm_encoder_outputs, scatter_mm_placeholders)
if TYPE_CHECKING: if TYPE_CHECKING:

View File

@ -42,11 +42,10 @@ from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, LogprobsLists,
LogprobsTensors, ModelRunnerOutput) LogprobsTensors, ModelRunnerOutput)
from vllm.v1.sample.tpu.metadata import TPUSupportedSamplingMetadata from vllm.v1.sample.tpu.metadata import TPUSupportedSamplingMetadata
from vllm.v1.sample.tpu.sampler import Sampler as TPUSampler from vllm.v1.sample.tpu.sampler import Sampler as TPUSampler
from vllm.v1.utils import bind_kv_cache
from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin
from vllm.v1.worker.tpu_input_batch import CachedRequestState, InputBatch from vllm.v1.worker.tpu_input_batch import CachedRequestState, InputBatch
from .utils import (initialize_kv_cache_for_kv_sharing, from .utils import (bind_kv_cache, initialize_kv_cache_for_kv_sharing,
sanity_check_mm_encoder_outputs) sanity_check_mm_encoder_outputs)
if TYPE_CHECKING: if TYPE_CHECKING:

View File

@ -25,8 +25,9 @@ from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.kv_cache_interface import (AttentionSpec, KVCacheConfig, from vllm.v1.kv_cache_interface import (AttentionSpec, KVCacheConfig,
KVCacheSpec) KVCacheSpec)
from vllm.v1.outputs import ModelRunnerOutput from vllm.v1.outputs import ModelRunnerOutput
from vllm.v1.utils import bind_kv_cache, report_usage_stats from vllm.v1.utils import report_usage_stats
from vllm.v1.worker.tpu_model_runner import TPUModelRunner from vllm.v1.worker.tpu_model_runner import TPUModelRunner
from vllm.v1.worker.utils import bind_kv_cache
logger = init_logger(__name__) logger = init_logger(__name__)

View File

@ -1,12 +1,17 @@
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Optional from collections import defaultdict
from typing import TYPE_CHECKING, Optional
import torch import torch
from vllm.model_executor.models.interfaces import MultiModalEmbeddings from vllm.model_executor.models.interfaces import MultiModalEmbeddings
from vllm.model_executor.models.utils import extract_layer_index
from vllm.v1.kv_cache_interface import KVCacheGroupSpec from vllm.v1.kv_cache_interface import KVCacheGroupSpec
if TYPE_CHECKING:
from vllm.attention.layer import Attention
def sanity_check_mm_encoder_outputs( def sanity_check_mm_encoder_outputs(
mm_embeddings: MultiModalEmbeddings, mm_embeddings: MultiModalEmbeddings,
@ -110,3 +115,48 @@ def initialize_kv_cache_for_kv_sharing(
kv_caches[layer_name] = kv_caches[target_layer_name] kv_caches[layer_name] = kv_caches[target_layer_name]
group_idx = layer_to_kv_cache_group_idx[target_layer_name] group_idx = layer_to_kv_cache_group_idx[target_layer_name]
kv_cache_groups[group_idx].layer_names.append(layer_name) kv_cache_groups[group_idx].layer_names.append(layer_name)
def bind_kv_cache(
kv_caches: dict[str, torch.Tensor],
forward_context: dict[str, "Attention"],
runner_kv_caches: list[torch.Tensor],
) -> None:
"""
Bind the allocated KV cache to both ModelRunner and forward context so
that the KV cache can be used in the forward pass.
This function:
1) Fills the ModelRunner's kv cache list (`runner_kv_caches`) with
kv_caches.
2) Associates each attention layer in the `forward_context` with its
corresponding KV cache in kv_caches.
Args:
kv_caches: The allocated kv_caches with layer names as keys.
forward_context: The global forward context containing all Attention
layers with layer names as keys.
runner_kv_caches: The kv_cache declared by ModelRunner.
"""
# Bind kv_caches to ModelRunner
assert len(runner_kv_caches) == 0
# Convert kv_caches dict to a list of tensors in the order of layer_index.
index2name = defaultdict(list)
for layer_name in kv_caches:
index2name[extract_layer_index(layer_name)].append(layer_name)
for layer_index in sorted(index2name.keys()):
layer_names = index2name[layer_index]
if len(layer_names) > 1:
# One typical case is encoder-decoder model, e.g., bart.
# The cross attention and self attention in the same decoder layer
# has different layer_name but the same layer_index.
raise NotImplementedError
layer_name = layer_names[0]
runner_kv_caches.append(kv_caches[layer_name])
# Bind kv_caches to forward context
for layer_name, kv_cache in kv_caches.items():
# NOTE: Use list because of v0 PP virtual engine.
forward_context[layer_name].kv_cache = [kv_cache]