mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-06-10 05:15:42 +08:00
[MISC] Move bind_kv_cache to worker module (#20900)
Signed-off-by: wangxiyuan <wangxiyuan1007@gmail.com>
This commit is contained in:
parent
697ef765ee
commit
1e9438e0b0
@ -3,7 +3,7 @@
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from vllm.v1.utils import bind_kv_cache
|
from vllm.v1.worker.utils import bind_kv_cache
|
||||||
|
|
||||||
|
|
||||||
def test_bind_kv_cache():
|
def test_bind_kv_cache():
|
||||||
|
|||||||
@ -4,7 +4,6 @@ import argparse
|
|||||||
import multiprocessing
|
import multiprocessing
|
||||||
import time
|
import time
|
||||||
import weakref
|
import weakref
|
||||||
from collections import defaultdict
|
|
||||||
from collections.abc import Sequence
|
from collections.abc import Sequence
|
||||||
from multiprocessing import connection
|
from multiprocessing import connection
|
||||||
from multiprocessing.process import BaseProcess
|
from multiprocessing.process import BaseProcess
|
||||||
@ -14,14 +13,12 @@ from typing import (TYPE_CHECKING, Any, Callable, Generic, Optional, TypeVar,
|
|||||||
import torch
|
import torch
|
||||||
|
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
from vllm.model_executor.models.utils import extract_layer_index
|
|
||||||
from vllm.usage.usage_lib import (UsageContext, is_usage_stats_enabled,
|
from vllm.usage.usage_lib import (UsageContext, is_usage_stats_enabled,
|
||||||
usage_message)
|
usage_message)
|
||||||
from vllm.utils import (get_open_port, get_open_zmq_ipc_path, get_tcp_uri,
|
from vllm.utils import (get_open_port, get_open_zmq_ipc_path, get_tcp_uri,
|
||||||
kill_process_tree)
|
kill_process_tree)
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from vllm.attention.layer import Attention
|
|
||||||
from vllm.v1.engine.coordinator import DPCoordinator
|
from vllm.v1.engine.coordinator import DPCoordinator
|
||||||
from vllm.v1.engine.utils import (CoreEngineActorManager,
|
from vllm.v1.engine.utils import (CoreEngineActorManager,
|
||||||
CoreEngineProcManager)
|
CoreEngineProcManager)
|
||||||
@ -275,51 +272,6 @@ def shutdown(procs: list[BaseProcess]):
|
|||||||
kill_process_tree(pid)
|
kill_process_tree(pid)
|
||||||
|
|
||||||
|
|
||||||
def bind_kv_cache(
|
|
||||||
kv_caches: dict[str, torch.Tensor],
|
|
||||||
forward_context: dict[str, "Attention"],
|
|
||||||
runner_kv_caches: list[torch.Tensor],
|
|
||||||
) -> None:
|
|
||||||
"""
|
|
||||||
Bind the allocated KV cache to both ModelRunner and forward context so
|
|
||||||
that the KV cache can be used in the forward pass.
|
|
||||||
|
|
||||||
This function:
|
|
||||||
1) Fills the ModelRunner's kv cache list (`runner_kv_caches`) with
|
|
||||||
kv_caches.
|
|
||||||
2) Associates each attention layer in the `forward_context` with its
|
|
||||||
corresponding KV cache in kv_caches.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
kv_caches: The allocated kv_caches with layer names as keys.
|
|
||||||
forward_context: The global forward context containing all Attention
|
|
||||||
layers with layer names as keys.
|
|
||||||
runner_kv_caches: The kv_cache declared by ModelRunner.
|
|
||||||
"""
|
|
||||||
# Bind kv_caches to ModelRunner
|
|
||||||
assert len(runner_kv_caches) == 0
|
|
||||||
|
|
||||||
# Convert kv_caches dict to a list of tensors in the order of layer_index.
|
|
||||||
index2name = defaultdict(list)
|
|
||||||
for layer_name in kv_caches:
|
|
||||||
index2name[extract_layer_index(layer_name)].append(layer_name)
|
|
||||||
|
|
||||||
for layer_index in sorted(index2name.keys()):
|
|
||||||
layer_names = index2name[layer_index]
|
|
||||||
if len(layer_names) > 1:
|
|
||||||
# One typical case is encoder-decoder model, e.g., bart.
|
|
||||||
# The cross attention and self attention in the same decoder layer
|
|
||||||
# has different layer_name but the same layer_index.
|
|
||||||
raise NotImplementedError
|
|
||||||
layer_name = layer_names[0]
|
|
||||||
runner_kv_caches.append(kv_caches[layer_name])
|
|
||||||
|
|
||||||
# Bind kv_caches to forward context
|
|
||||||
for layer_name, kv_cache in kv_caches.items():
|
|
||||||
# NOTE: Use list because of v0 PP virtual engine.
|
|
||||||
forward_context[layer_name].kv_cache = [kv_cache]
|
|
||||||
|
|
||||||
|
|
||||||
def copy_slice(from_tensor: torch.Tensor, to_tensor: torch.Tensor,
|
def copy_slice(from_tensor: torch.Tensor, to_tensor: torch.Tensor,
|
||||||
length: int) -> torch.Tensor:
|
length: int) -> torch.Tensor:
|
||||||
"""
|
"""
|
||||||
|
|||||||
@ -62,13 +62,13 @@ from vllm.v1.spec_decode.eagle import EagleProposer
|
|||||||
from vllm.v1.spec_decode.medusa import MedusaProposer
|
from vllm.v1.spec_decode.medusa import MedusaProposer
|
||||||
from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
|
from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
|
||||||
from vllm.v1.spec_decode.ngram_proposer import NgramProposer
|
from vllm.v1.spec_decode.ngram_proposer import NgramProposer
|
||||||
from vllm.v1.utils import bind_kv_cache
|
|
||||||
from vllm.v1.worker.block_table import BlockTable
|
from vllm.v1.worker.block_table import BlockTable
|
||||||
from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
|
from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
|
||||||
from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin
|
from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin
|
||||||
|
|
||||||
from ..sample.logits_processor import LogitsProcessorManager
|
from ..sample.logits_processor import LogitsProcessorManager
|
||||||
from .utils import (gather_mm_placeholders, initialize_kv_cache_for_kv_sharing,
|
from .utils import (bind_kv_cache, gather_mm_placeholders,
|
||||||
|
initialize_kv_cache_for_kv_sharing,
|
||||||
sanity_check_mm_encoder_outputs, scatter_mm_placeholders)
|
sanity_check_mm_encoder_outputs, scatter_mm_placeholders)
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
|
|||||||
@ -42,11 +42,10 @@ from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, LogprobsLists,
|
|||||||
LogprobsTensors, ModelRunnerOutput)
|
LogprobsTensors, ModelRunnerOutput)
|
||||||
from vllm.v1.sample.tpu.metadata import TPUSupportedSamplingMetadata
|
from vllm.v1.sample.tpu.metadata import TPUSupportedSamplingMetadata
|
||||||
from vllm.v1.sample.tpu.sampler import Sampler as TPUSampler
|
from vllm.v1.sample.tpu.sampler import Sampler as TPUSampler
|
||||||
from vllm.v1.utils import bind_kv_cache
|
|
||||||
from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin
|
from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin
|
||||||
from vllm.v1.worker.tpu_input_batch import CachedRequestState, InputBatch
|
from vllm.v1.worker.tpu_input_batch import CachedRequestState, InputBatch
|
||||||
|
|
||||||
from .utils import (initialize_kv_cache_for_kv_sharing,
|
from .utils import (bind_kv_cache, initialize_kv_cache_for_kv_sharing,
|
||||||
sanity_check_mm_encoder_outputs)
|
sanity_check_mm_encoder_outputs)
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
|
|||||||
@ -25,8 +25,9 @@ from vllm.v1.core.sched.output import SchedulerOutput
|
|||||||
from vllm.v1.kv_cache_interface import (AttentionSpec, KVCacheConfig,
|
from vllm.v1.kv_cache_interface import (AttentionSpec, KVCacheConfig,
|
||||||
KVCacheSpec)
|
KVCacheSpec)
|
||||||
from vllm.v1.outputs import ModelRunnerOutput
|
from vllm.v1.outputs import ModelRunnerOutput
|
||||||
from vllm.v1.utils import bind_kv_cache, report_usage_stats
|
from vllm.v1.utils import report_usage_stats
|
||||||
from vllm.v1.worker.tpu_model_runner import TPUModelRunner
|
from vllm.v1.worker.tpu_model_runner import TPUModelRunner
|
||||||
|
from vllm.v1.worker.utils import bind_kv_cache
|
||||||
|
|
||||||
logger = init_logger(__name__)
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
|
|||||||
@ -1,12 +1,17 @@
|
|||||||
# SPDX-License-Identifier: Apache-2.0
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
from typing import Optional
|
from collections import defaultdict
|
||||||
|
from typing import TYPE_CHECKING, Optional
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from vllm.model_executor.models.interfaces import MultiModalEmbeddings
|
from vllm.model_executor.models.interfaces import MultiModalEmbeddings
|
||||||
|
from vllm.model_executor.models.utils import extract_layer_index
|
||||||
from vllm.v1.kv_cache_interface import KVCacheGroupSpec
|
from vllm.v1.kv_cache_interface import KVCacheGroupSpec
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from vllm.attention.layer import Attention
|
||||||
|
|
||||||
|
|
||||||
def sanity_check_mm_encoder_outputs(
|
def sanity_check_mm_encoder_outputs(
|
||||||
mm_embeddings: MultiModalEmbeddings,
|
mm_embeddings: MultiModalEmbeddings,
|
||||||
@ -110,3 +115,48 @@ def initialize_kv_cache_for_kv_sharing(
|
|||||||
kv_caches[layer_name] = kv_caches[target_layer_name]
|
kv_caches[layer_name] = kv_caches[target_layer_name]
|
||||||
group_idx = layer_to_kv_cache_group_idx[target_layer_name]
|
group_idx = layer_to_kv_cache_group_idx[target_layer_name]
|
||||||
kv_cache_groups[group_idx].layer_names.append(layer_name)
|
kv_cache_groups[group_idx].layer_names.append(layer_name)
|
||||||
|
|
||||||
|
|
||||||
|
def bind_kv_cache(
|
||||||
|
kv_caches: dict[str, torch.Tensor],
|
||||||
|
forward_context: dict[str, "Attention"],
|
||||||
|
runner_kv_caches: list[torch.Tensor],
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
Bind the allocated KV cache to both ModelRunner and forward context so
|
||||||
|
that the KV cache can be used in the forward pass.
|
||||||
|
|
||||||
|
This function:
|
||||||
|
1) Fills the ModelRunner's kv cache list (`runner_kv_caches`) with
|
||||||
|
kv_caches.
|
||||||
|
2) Associates each attention layer in the `forward_context` with its
|
||||||
|
corresponding KV cache in kv_caches.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
kv_caches: The allocated kv_caches with layer names as keys.
|
||||||
|
forward_context: The global forward context containing all Attention
|
||||||
|
layers with layer names as keys.
|
||||||
|
runner_kv_caches: The kv_cache declared by ModelRunner.
|
||||||
|
"""
|
||||||
|
# Bind kv_caches to ModelRunner
|
||||||
|
assert len(runner_kv_caches) == 0
|
||||||
|
|
||||||
|
# Convert kv_caches dict to a list of tensors in the order of layer_index.
|
||||||
|
index2name = defaultdict(list)
|
||||||
|
for layer_name in kv_caches:
|
||||||
|
index2name[extract_layer_index(layer_name)].append(layer_name)
|
||||||
|
|
||||||
|
for layer_index in sorted(index2name.keys()):
|
||||||
|
layer_names = index2name[layer_index]
|
||||||
|
if len(layer_names) > 1:
|
||||||
|
# One typical case is encoder-decoder model, e.g., bart.
|
||||||
|
# The cross attention and self attention in the same decoder layer
|
||||||
|
# has different layer_name but the same layer_index.
|
||||||
|
raise NotImplementedError
|
||||||
|
layer_name = layer_names[0]
|
||||||
|
runner_kv_caches.append(kv_caches[layer_name])
|
||||||
|
|
||||||
|
# Bind kv_caches to forward context
|
||||||
|
for layer_name, kv_cache in kv_caches.items():
|
||||||
|
# NOTE: Use list because of v0 PP virtual engine.
|
||||||
|
forward_context[layer_name].kv_cache = [kv_cache]
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user