mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-01-28 23:07:14 +08:00
[MISC] Move bind_kv_cache to worker module (#20900)
Signed-off-by: wangxiyuan <wangxiyuan1007@gmail.com>
This commit is contained in:
parent
697ef765ee
commit
1e9438e0b0
@ -3,7 +3,7 @@
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.v1.utils import bind_kv_cache
|
||||
from vllm.v1.worker.utils import bind_kv_cache
|
||||
|
||||
|
||||
def test_bind_kv_cache():
|
||||
|
||||
@ -4,7 +4,6 @@ import argparse
|
||||
import multiprocessing
|
||||
import time
|
||||
import weakref
|
||||
from collections import defaultdict
|
||||
from collections.abc import Sequence
|
||||
from multiprocessing import connection
|
||||
from multiprocessing.process import BaseProcess
|
||||
@ -14,14 +13,12 @@ from typing import (TYPE_CHECKING, Any, Callable, Generic, Optional, TypeVar,
|
||||
import torch
|
||||
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.models.utils import extract_layer_index
|
||||
from vllm.usage.usage_lib import (UsageContext, is_usage_stats_enabled,
|
||||
usage_message)
|
||||
from vllm.utils import (get_open_port, get_open_zmq_ipc_path, get_tcp_uri,
|
||||
kill_process_tree)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.attention.layer import Attention
|
||||
from vllm.v1.engine.coordinator import DPCoordinator
|
||||
from vllm.v1.engine.utils import (CoreEngineActorManager,
|
||||
CoreEngineProcManager)
|
||||
@ -275,51 +272,6 @@ def shutdown(procs: list[BaseProcess]):
|
||||
kill_process_tree(pid)
|
||||
|
||||
|
||||
def bind_kv_cache(
|
||||
kv_caches: dict[str, torch.Tensor],
|
||||
forward_context: dict[str, "Attention"],
|
||||
runner_kv_caches: list[torch.Tensor],
|
||||
) -> None:
|
||||
"""
|
||||
Bind the allocated KV cache to both ModelRunner and forward context so
|
||||
that the KV cache can be used in the forward pass.
|
||||
|
||||
This function:
|
||||
1) Fills the ModelRunner's kv cache list (`runner_kv_caches`) with
|
||||
kv_caches.
|
||||
2) Associates each attention layer in the `forward_context` with its
|
||||
corresponding KV cache in kv_caches.
|
||||
|
||||
Args:
|
||||
kv_caches: The allocated kv_caches with layer names as keys.
|
||||
forward_context: The global forward context containing all Attention
|
||||
layers with layer names as keys.
|
||||
runner_kv_caches: The kv_cache declared by ModelRunner.
|
||||
"""
|
||||
# Bind kv_caches to ModelRunner
|
||||
assert len(runner_kv_caches) == 0
|
||||
|
||||
# Convert kv_caches dict to a list of tensors in the order of layer_index.
|
||||
index2name = defaultdict(list)
|
||||
for layer_name in kv_caches:
|
||||
index2name[extract_layer_index(layer_name)].append(layer_name)
|
||||
|
||||
for layer_index in sorted(index2name.keys()):
|
||||
layer_names = index2name[layer_index]
|
||||
if len(layer_names) > 1:
|
||||
# One typical case is encoder-decoder model, e.g., bart.
|
||||
# The cross attention and self attention in the same decoder layer
|
||||
# has different layer_name but the same layer_index.
|
||||
raise NotImplementedError
|
||||
layer_name = layer_names[0]
|
||||
runner_kv_caches.append(kv_caches[layer_name])
|
||||
|
||||
# Bind kv_caches to forward context
|
||||
for layer_name, kv_cache in kv_caches.items():
|
||||
# NOTE: Use list because of v0 PP virtual engine.
|
||||
forward_context[layer_name].kv_cache = [kv_cache]
|
||||
|
||||
|
||||
def copy_slice(from_tensor: torch.Tensor, to_tensor: torch.Tensor,
|
||||
length: int) -> torch.Tensor:
|
||||
"""
|
||||
|
||||
@ -62,13 +62,13 @@ from vllm.v1.spec_decode.eagle import EagleProposer
|
||||
from vllm.v1.spec_decode.medusa import MedusaProposer
|
||||
from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
|
||||
from vllm.v1.spec_decode.ngram_proposer import NgramProposer
|
||||
from vllm.v1.utils import bind_kv_cache
|
||||
from vllm.v1.worker.block_table import BlockTable
|
||||
from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
|
||||
from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin
|
||||
|
||||
from ..sample.logits_processor import LogitsProcessorManager
|
||||
from .utils import (gather_mm_placeholders, initialize_kv_cache_for_kv_sharing,
|
||||
from .utils import (bind_kv_cache, gather_mm_placeholders,
|
||||
initialize_kv_cache_for_kv_sharing,
|
||||
sanity_check_mm_encoder_outputs, scatter_mm_placeholders)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
|
||||
@ -42,11 +42,10 @@ from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, LogprobsLists,
|
||||
LogprobsTensors, ModelRunnerOutput)
|
||||
from vllm.v1.sample.tpu.metadata import TPUSupportedSamplingMetadata
|
||||
from vllm.v1.sample.tpu.sampler import Sampler as TPUSampler
|
||||
from vllm.v1.utils import bind_kv_cache
|
||||
from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin
|
||||
from vllm.v1.worker.tpu_input_batch import CachedRequestState, InputBatch
|
||||
|
||||
from .utils import (initialize_kv_cache_for_kv_sharing,
|
||||
from .utils import (bind_kv_cache, initialize_kv_cache_for_kv_sharing,
|
||||
sanity_check_mm_encoder_outputs)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
|
||||
@ -25,8 +25,9 @@ from vllm.v1.core.sched.output import SchedulerOutput
|
||||
from vllm.v1.kv_cache_interface import (AttentionSpec, KVCacheConfig,
|
||||
KVCacheSpec)
|
||||
from vllm.v1.outputs import ModelRunnerOutput
|
||||
from vllm.v1.utils import bind_kv_cache, report_usage_stats
|
||||
from vllm.v1.utils import report_usage_stats
|
||||
from vllm.v1.worker.tpu_model_runner import TPUModelRunner
|
||||
from vllm.v1.worker.utils import bind_kv_cache
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
@ -1,12 +1,17 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from typing import Optional
|
||||
from collections import defaultdict
|
||||
from typing import TYPE_CHECKING, Optional
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.model_executor.models.interfaces import MultiModalEmbeddings
|
||||
from vllm.model_executor.models.utils import extract_layer_index
|
||||
from vllm.v1.kv_cache_interface import KVCacheGroupSpec
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.attention.layer import Attention
|
||||
|
||||
|
||||
def sanity_check_mm_encoder_outputs(
|
||||
mm_embeddings: MultiModalEmbeddings,
|
||||
@ -110,3 +115,48 @@ def initialize_kv_cache_for_kv_sharing(
|
||||
kv_caches[layer_name] = kv_caches[target_layer_name]
|
||||
group_idx = layer_to_kv_cache_group_idx[target_layer_name]
|
||||
kv_cache_groups[group_idx].layer_names.append(layer_name)
|
||||
|
||||
|
||||
def bind_kv_cache(
|
||||
kv_caches: dict[str, torch.Tensor],
|
||||
forward_context: dict[str, "Attention"],
|
||||
runner_kv_caches: list[torch.Tensor],
|
||||
) -> None:
|
||||
"""
|
||||
Bind the allocated KV cache to both ModelRunner and forward context so
|
||||
that the KV cache can be used in the forward pass.
|
||||
|
||||
This function:
|
||||
1) Fills the ModelRunner's kv cache list (`runner_kv_caches`) with
|
||||
kv_caches.
|
||||
2) Associates each attention layer in the `forward_context` with its
|
||||
corresponding KV cache in kv_caches.
|
||||
|
||||
Args:
|
||||
kv_caches: The allocated kv_caches with layer names as keys.
|
||||
forward_context: The global forward context containing all Attention
|
||||
layers with layer names as keys.
|
||||
runner_kv_caches: The kv_cache declared by ModelRunner.
|
||||
"""
|
||||
# Bind kv_caches to ModelRunner
|
||||
assert len(runner_kv_caches) == 0
|
||||
|
||||
# Convert kv_caches dict to a list of tensors in the order of layer_index.
|
||||
index2name = defaultdict(list)
|
||||
for layer_name in kv_caches:
|
||||
index2name[extract_layer_index(layer_name)].append(layer_name)
|
||||
|
||||
for layer_index in sorted(index2name.keys()):
|
||||
layer_names = index2name[layer_index]
|
||||
if len(layer_names) > 1:
|
||||
# One typical case is encoder-decoder model, e.g., bart.
|
||||
# The cross attention and self attention in the same decoder layer
|
||||
# has different layer_name but the same layer_index.
|
||||
raise NotImplementedError
|
||||
layer_name = layer_names[0]
|
||||
runner_kv_caches.append(kv_caches[layer_name])
|
||||
|
||||
# Bind kv_caches to forward context
|
||||
for layer_name, kv_cache in kv_caches.items():
|
||||
# NOTE: Use list because of v0 PP virtual engine.
|
||||
forward_context[layer_name].kv_cache = [kv_cache]
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user