[V1] [P/D] Refactor KV Connector Path (#21980)
Signed-off-by: David Ben-David <davidb@pliops.com>
Co-authored-by: David Ben-David <davidb@pliops.com>
Parent: 24d1dffbeb
Commit: aefeea0fde
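In short, this refactor replaces the loose finished_sending / finished_recving fields that were threaded through ModelRunnerOutput, IntermediateTensors, and the worker path with a single nested KVConnectorOutput object. A minimal sketch of the shape change (only KVConnectorOutput's fields are taken from this diff; the ModelRunnerOutputAfter stand-in is illustrative, not vLLM's full class):

    # Toy before/after sketch of the output shape change.
    from dataclasses import dataclass
    from typing import Optional

    @dataclass
    class KVConnectorOutput:
        # [req_ids]
        finished_sending: Optional[set[str]] = None
        finished_recving: Optional[set[str]] = None

    @dataclass
    class ModelRunnerOutputAfter:
        # Before: two Optional[set[str]] fields sat directly on
        # ModelRunnerOutput. After: one optional nested object.
        kv_connector_output: Optional[KVConnectorOutput] = None

    out = ModelRunnerOutputAfter(
        kv_connector_output=KVConnectorOutput(finished_sending={"req-1"}))
    assert out.kv_connector_output.finished_sending == {"req-1"}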
@@ -4,7 +4,7 @@ from concurrent.futures import Future
 from typing import Optional
 
 from vllm.distributed.kv_transfer.kv_connector.utils import KVOutputAggregator
-from vllm.v1.outputs import ModelRunnerOutput
+from vllm.v1.outputs import KVConnectorOutput, ModelRunnerOutput
 
 
 class DummyModelRunnerOutput(ModelRunnerOutput):
@@ -12,8 +12,16 @@ class DummyModelRunnerOutput(ModelRunnerOutput):
     def __init__(self,
                  finished_sending: Optional[set[str]] = None,
                  finished_recving: Optional[set[str]] = None):
-        self.finished_sending = finished_sending
-        self.finished_recving = finished_recving
+        self.kv_connector_output = KVConnectorOutput(
+            finished_sending=finished_sending,
+            finished_recving=finished_recving,
+        )
+
+    def __repr__(self):
+        return (
+            f"DummyModelRunnerOutput("
+            f"finished_sending={self.kv_connector_output.finished_sending},"
+            f"finished_recving={self.kv_connector_output.finished_recving})")
 
 
 def test_aggregate_workers_output():
@@ -27,6 +35,7 @@ def test_aggregate_workers_output():
     aggregated = aggregator.aggregate([output1, output2])
 
     assert aggregated is output1
+    aggregated = aggregated.kv_connector_output
     assert aggregated.finished_sending is None
     assert aggregated.finished_recving is None
 
@@ -38,6 +47,7 @@ def test_aggregate_workers_output():
     aggregated = aggregator.aggregate([output1, output2])
 
     assert aggregated is output1
+    aggregated = aggregated.kv_connector_output
     assert aggregated.finished_sending == {'req1'}
     assert aggregated.finished_recving is None
 
@@ -49,6 +59,7 @@ def test_aggregate_workers_output():
     aggregated = aggregator.aggregate([output1, output2])
 
     assert aggregated is output1
+    aggregated = aggregated.kv_connector_output
     assert aggregated.finished_sending is None
     assert aggregated.finished_recving == {'req2'}
 
@@ -70,6 +81,7 @@ def test_async_aggregate_workers_output():
     assert result_future.done()
     aggregated = result_future.result()
     assert aggregated is output1
+    aggregated = aggregated.kv_connector_output
     assert aggregated.finished_sending is None
     assert aggregated.finished_recving is None
 
@@ -87,6 +99,7 @@ def test_async_aggregate_workers_output():
     assert result_future.done()
     aggregated = result_future.result()
     assert aggregated is output1
+    aggregated = aggregated.kv_connector_output
     assert aggregated.finished_sending == {'req1'}
     assert aggregated.finished_recving is None
 
@@ -104,5 +117,6 @@ def test_async_aggregate_workers_output():
     assert result_future.done()
     aggregated = result_future.result()
     assert aggregated is output1
+    aggregated = aggregated.kv_connector_output
     assert aggregated.finished_sending is None
     assert aggregated.finished_recving == {'req2'}
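Taken together, these tests pin down the aggregation contract: a request id only surfaces in the aggregate once every worker has reported it. A hedged driver sketch using the DummyModelRunnerOutput defined above (the world_size constructor argument is an assumption based on how these tests construct the aggregator):

    # Assumed 2-worker setup; KVOutputAggregator and DummyModelRunnerOutput
    # come from the code above.
    from vllm.distributed.kv_transfer.kv_connector.utils import KVOutputAggregator

    aggregator = KVOutputAggregator(world_size=2)

    # Step 1: only worker 0 reports req1; the aggregate stays None.
    outs = [DummyModelRunnerOutput(finished_sending={'req1'}),
            DummyModelRunnerOutput()]
    agg = aggregator.aggregate(outs).kv_connector_output
    assert agg.finished_sending is None

    # Step 2: worker 1 catches up; now the aggregate surfaces req1.
    outs = [DummyModelRunnerOutput(),
            DummyModelRunnerOutput(finished_sending={'req1'})]
    agg = aggregator.aggregate(outs).kv_connector_output
    assert agg.finished_sending == {'req1'}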
@@ -2,7 +2,7 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 import copy
 
-from vllm.v1.outputs import EMPTY_MODEL_RUNNER_OUTPUT
+from vllm.v1.outputs import EMPTY_MODEL_RUNNER_OUTPUT, KVConnectorOutput
 from vllm.v1.request import FinishReason, RequestStatus
 
 from .utils import (assert_scheduler_empty, create_model_runner_output,
@@ -86,7 +86,8 @@ def test_basic_lifecycle():
 
     # (3b): execute_model()
     model_runner_output = copy.deepcopy(EMPTY_MODEL_RUNNER_OUTPUT)
-    model_runner_output.finished_sending = [request_id]
+    model_runner_output.kv_connector_output = KVConnectorOutput(
+        finished_sending=[request_id])
 
     # (3c): update_from_output()
     scheduler.update_from_output(scheduler_output, model_runner_output)
@@ -176,7 +177,8 @@ def test_prefix_cache_lifecycle():
     scheduler_output = scheduler.schedule()
     scheduler.schedule()
     model_runner_output = copy.deepcopy(EMPTY_MODEL_RUNNER_OUTPUT)
-    model_runner_output.finished_sending = [request_remote.request_id]
+    model_runner_output.kv_connector_output = KVConnectorOutput(
+        finished_sending=[request_remote.request_id])
     scheduler.update_from_output(scheduler_output, model_runner_output)
     _ = scheduler.schedule()
     assert_scheduler_empty(scheduler)
@@ -2,7 +2,7 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 import copy
 
-from vllm.v1.outputs import EMPTY_MODEL_RUNNER_OUTPUT
+from vllm.v1.outputs import EMPTY_MODEL_RUNNER_OUTPUT, KVConnectorOutput
 from vllm.v1.request import FinishReason, RequestStatus
 
 from .utils import (assert_scheduler_empty, create_model_runner_output,
@@ -72,7 +72,8 @@ def test_basic_lifecycle():
 
     # (2b): forward(): request finishes recv.
     model_runner_output = copy.deepcopy(EMPTY_MODEL_RUNNER_OUTPUT)
-    model_runner_output.finished_recving = [request_id]
+    model_runner_output.kv_connector_output = KVConnectorOutput(
+        finished_recving=[request_id])
 
     # (2c): update_from_output():
     engine_core_outputs = scheduler.update_from_output(scheduler_output,
@@ -309,7 +310,8 @@ def test_full_block_prompt():
     # # STEP (2): Recv.
     scheduler_output = scheduler.schedule()
     model_runner_output = copy.deepcopy(EMPTY_MODEL_RUNNER_OUTPUT)
-    model_runner_output.finished_recving = [request_id]
+    model_runner_output.kv_connector_output = KVConnectorOutput(
+        finished_recving=[request_id])
     scheduler.update_from_output(scheduler_output, model_runner_output)
     assert len(scheduler.waiting) == 1
     assert (request_id in scheduler.finished_recving_kv_req_ids)
@@ -17,7 +17,7 @@ from vllm.v1.core.kv_cache_manager import KVCacheBlocks
 from vllm.v1.core.sched.scheduler import Scheduler
 from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
                                         KVCacheGroupSpec)
-from vllm.v1.outputs import ModelRunnerOutput
+from vllm.v1.outputs import KVConnectorOutput, ModelRunnerOutput
 from vllm.v1.request import Request
 from vllm.v1.structured_output import StructuredOutputManager
 
@@ -188,8 +188,10 @@ def create_model_runner_output(
         logprobs=None,
         prompt_logprobs_dict={},
         pooler_output=None,
-        finished_sending=finished_sending,
-        finished_recving=finished_recving,
+        kv_connector_output=KVConnectorOutput(
+            finished_sending=finished_sending,
+            finished_recving=finished_recving,
+        ),
     )
@@ -16,7 +16,7 @@ from vllm.config import VllmConfig, get_current_vllm_config
 from vllm.distributed.kv_transfer.kv_connector.factory import (
     KVConnectorFactory)
 from vllm.logger import init_logger
-from vllm.v1.outputs import ModelRunnerOutput
+from vllm.v1.outputs import KVConnectorOutput, ModelRunnerOutput
 
 logger = init_logger(__name__)
 
@@ -129,7 +129,7 @@ class KVOutputAggregator:
     def aggregate(self,
                   outputs: list[ModelRunnerOutput],
                   output_rank: int = 0) -> ModelRunnerOutput:
-        # aggregate finished_sending, finished_recving from all workers
+        # aggregate kv_connector_output from all workers
 
         def update_finished_set(req_ids: Optional[set[str]],
                                 remaining_count_dict: dict[str, int],
@@ -143,6 +143,7 @@ class KVOutputAggregator:
         finished_sending = set[str]()
         finished_recving = set[str]()
         for output in outputs:
+            output = output.kv_connector_output
             update_finished_set(output.finished_sending,
                                 self._send_remaining_count, finished_sending)
             update_finished_set(output.finished_recving,
@@ -151,13 +152,10 @@ class KVOutputAggregator:
         # select output of the worker specified by output_rank
         output = outputs[output_rank]
 
-        # set the aggregated finished_sending / finished_recving
-        # if output.finished_sending/recving is not empty, but the other ranks
-        # still have unfinished send/recv, we want to set the aggregated
-        # finished_sending/recving to None until all ranks have finished
-        # send/recv
-        output.finished_sending = finished_sending if finished_sending else None
-        output.finished_recving = finished_recving if finished_recving else None
+        output.kv_connector_output = KVConnectorOutput(
+            finished_sending=finished_sending or None,
+            finished_recving=finished_recving or None,
+        )
 
         return output
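update_finished_set itself is not shown in full in this diff; here is a minimal re-implementation of the counting rule it implies, assuming each request must be reported by world_size workers before it counts as finished (illustrative only, the real helper lives inside KVOutputAggregator and may differ in detail):

    # Standalone sketch of the assumed per-request countdown.
    from typing import Optional

    WORLD_SIZE = 2  # assumed worker count

    def update_finished_set(req_ids: Optional[set[str]],
                            remaining_count_dict: dict[str, int],
                            finished_set: set[str]) -> None:
        for req_id in req_ids or ():
            # Each worker report decrements the remaining count; the request
            # is finished once every worker has reported it.
            remaining = remaining_count_dict.get(req_id, WORLD_SIZE) - 1
            if remaining == 0:
                finished_set.add(req_id)
                remaining_count_dict.pop(req_id, None)
            else:
                remaining_count_dict[req_id] = remaining

    send_remaining: dict[str, int] = {}
    finished: set[str] = set()
    update_finished_set({'req1'}, send_remaining, finished)  # worker 0 reports
    assert not finished
    update_finished_set({'req1'}, send_remaining, finished)  # worker 1 reports
    assert finished == {'req1'}

This also explains the `or None` in the hunk above: an empty set means "nothing newly finished this step", which downstream code distinguishes via None.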
@@ -10,7 +10,7 @@ from collections.abc import Mapping
 from collections.abc import Sequence as GenericSequence
 from dataclasses import dataclass, field
 from functools import reduce
-from typing import Any, Callable, Optional, Union
+from typing import TYPE_CHECKING, Any, Callable, Optional, Union
 
 import msgspec
 import torch
@@ -21,6 +21,10 @@ from vllm.multimodal import MultiModalKwargs, MultiModalPlaceholderDict
 from vllm.pooling_params import PoolingParams
 from vllm.sampling_params import RequestOutputKind, SamplingParams
 
+if TYPE_CHECKING:
+    from vllm.v1.worker.kv_connector_model_runner_mixin import (
+        KVConnectorOutput)
+
 VLLM_TOKEN_ID_ARRAY_TYPE = "l"
 
 VLLM_INVALID_TOKEN_ID = -1
@@ -1159,14 +1163,11 @@ class IntermediateTensors:
     states and residuals to be sent to the next stage. This data structure
     contains the hidden states and residuals for a request.
 
-    Each stage also needs to handle its own finished_sending and
-    finished_recving in case of kv transfer.
+    Each stage also needs to handle its own kv_connector_output.
     """
 
     tensors: dict[str, torch.Tensor]
-    # [req_ids]
-    finished_sending: Optional[set[str]] = None
-    finished_recving: Optional[set[str]] = None
+    kv_connector_output: Optional["KVConnectorOutput"]
 
     def __init__(self, tensors):
         # manually define this function, so that
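Since IntermediateTensors defines __init__ manually, the new field is set by plain attribute assignment after construction. A hedged sketch of how a mid-pipeline stage tags its tensors (constructing IntermediateTensors directly like this is purely illustrative):

    # Illustrative only: a mid-pipeline stage attaching its connector output.
    import torch

    from vllm.sequence import IntermediateTensors
    from vllm.v1.outputs import KVConnectorOutput

    hidden_states = IntermediateTensors({"hidden_states": torch.zeros(1, 8)})
    hidden_states.kv_connector_output = KVConnectorOutput(
        finished_recving={"req-7"})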
@@ -30,7 +30,7 @@ from vllm.v1.engine import (EngineCoreEventType, EngineCoreOutput,
                             EngineCoreOutputs)
 from vllm.v1.kv_cache_interface import KVCacheConfig
 from vllm.v1.metrics.stats import SchedulerStats
-from vllm.v1.outputs import ModelRunnerOutput
+from vllm.v1.outputs import KVConnectorOutput, ModelRunnerOutput
 from vllm.v1.request import Request, RequestStatus
 from vllm.v1.spec_decode.metrics import SpecDecodingStats
 from vllm.v1.structured_output import StructuredOutputManager
@@ -884,7 +884,9 @@ class Scheduler(SchedulerInterface):
         self.waiting.remove_requests(stopped_preempted_reqs)
 
         # KV Connector: update state for finished KV Transfers.
-        self._update_from_kv_xfer_finished(model_runner_output)
+        if model_runner_output.kv_connector_output:
+            self._update_from_kv_xfer_finished(
+                model_runner_output.kv_connector_output)
 
         # Create EngineCoreOutputs for all clients that have requests with
         # outputs in this step.
@@ -1128,7 +1130,7 @@ class Scheduler(SchedulerInterface):
         return True
 
     def _update_from_kv_xfer_finished(self,
-                                      model_runner_output: ModelRunnerOutput):
+                                      kv_connector_output: KVConnectorOutput):
         """
         KV Connector: update the scheduler state based on the output.
 
@@ -1139,9 +1141,9 @@ class Scheduler(SchedulerInterface):
         scheduler the request during the next step.
         """
         # KV Connector:: update recv and send status from last step.
-        for req_id in (model_runner_output.finished_recving or ()):
+        for req_id in (kv_connector_output.finished_recving or ()):
             logger.debug("Finished recving KV transfer for request %s", req_id)
             self.finished_recving_kv_req_ids.add(req_id)
-        for req_id in (model_runner_output.finished_sending or ()):
+        for req_id in (kv_connector_output.finished_sending or ()):
             logger.debug("Finished sending KV transfer for request %s", req_id)
             self._free_blocks(self.requests[req_id])
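The scheduler-side consumption is symmetric: a finished receive makes the request schedulable on the next step, a finished send lets its blocks be freed. A toy stand-in for the two loops above (MiniScheduler is illustrative, not vLLM's Scheduler):

    # Toy model of _update_from_kv_xfer_finished's bookkeeping.
    from vllm.v1.outputs import KVConnectorOutput

    class MiniScheduler:
        def __init__(self):
            self.finished_recving_kv_req_ids: set[str] = set()
            self.freed_block_reqs: list[str] = []

        def update_from_kv_xfer_finished(self, out: KVConnectorOutput):
            for req_id in (out.finished_recving or ()):
                # KV arrived: the request can be scheduled next step.
                self.finished_recving_kv_req_ids.add(req_id)
            for req_id in (out.finished_sending or ()):
                # Remote side consumed our KV: free the blocks.
                self.freed_block_reqs.append(req_id)

    s = MiniScheduler()
    s.update_from_kv_xfer_finished(KVConnectorOutput(finished_sending={"req-1"}))
    assert s.freed_block_reqs == ["req-1"]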
@@ -71,6 +71,13 @@ class SamplerOutput:
     logprobs_tensors: Optional[LogprobsTensors]
 
 
+@dataclass
+class KVConnectorOutput:
+    # [req_ids]
+    finished_sending: Optional[set[str]] = None
+    finished_recving: Optional[set[str]] = None
+
+
 # ModelRunnerOutput is serialized and sent to the scheduler process.
 # This is expensive for torch.Tensor so prefer to use list instead.
 @dataclass
@@ -104,9 +111,7 @@ class ModelRunnerOutput:
     # [num_reqs, hidden_size]
     pooler_output: list[Optional[torch.Tensor]]
 
-    # [req_ids]
-    finished_sending: Optional[set[str]] = None
-    finished_recving: Optional[set[str]] = None
+    kv_connector_output: Optional[KVConnectorOutput] = None
 
     # req_id -> num_nans_in_logits
     num_nans_in_logits: Optional[dict[str, int]] = None
@@ -119,6 +124,4 @@ EMPTY_MODEL_RUNNER_OUTPUT = ModelRunnerOutput(req_ids=[],
                                               logprobs=None,
                                               prompt_logprobs_dict={},
                                               pooler_output=[],
-                                              finished_sending=None,
-                                              finished_recving=None,
                                               num_nans_in_logits=None)
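With the dataclass in place, producers attach at most one optional object instead of two loose fields; the deepcopy-then-assign pattern from the updated tests above is the simplest way to build one by hand:

    # Mirrors the updated tests; runnable against a vLLM checkout at this commit.
    import copy

    from vllm.v1.outputs import EMPTY_MODEL_RUNNER_OUTPUT, KVConnectorOutput

    model_runner_output = copy.deepcopy(EMPTY_MODEL_RUNNER_OUTPUT)
    model_runner_output.kv_connector_output = KVConnectorOutput(
        finished_sending={'req-123'})

    # Consumers branch on the single optional field, as the scheduler now does.
    if model_runner_output.kv_connector_output:
        print(model_runner_output.kv_connector_output.finished_sending)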
@@ -69,7 +69,7 @@ from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
 from vllm.v1.spec_decode.ngram_proposer import NgramProposer
 from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
 from vllm.v1.worker.kv_connector_model_runner_mixin import (
-    KVConnectorModelRunnerMixin)
+    KVConnectorModelRunnerMixin, KVConnectorOutput)
 from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin
 
 from ..sample.logits_processor import LogitsProcessorManager
@@ -1423,8 +1423,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         hidden_states: torch.Tensor,
         num_scheduled_tokens: int,
         num_scheduled_tokens_np: np.ndarray,
-        finished_sending: Optional[set[str]],
-        finished_recving: Optional[set[str]],
+        kv_connector_output: Optional[KVConnectorOutput],
     ) -> ModelRunnerOutput:
         assert self.input_batch.num_reqs ==\
             len(self.input_batch.pooling_params), \
@@ -1459,8 +1458,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
             logprobs=None,
             prompt_logprobs_dict={},
             pooler_output=pooler_output,
-            finished_sending=finished_sending,
-            finished_recving=finished_recving,
+            kv_connector_output=kv_connector_output,
         )
 
     @torch.inference_mode()
@@ -1564,8 +1562,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
             num_tokens=num_input_tokens,
             num_tokens_across_dp=num_tokens_across_dp,
             skip_cuda_graphs=skip_cuda_graphs,
-        ):
-            self.maybe_setup_kv_connector(scheduler_output)
+        ), self.maybe_get_kv_connector_output(
+                scheduler_output) as kv_connector_output:
 
             model_output = self.model(
                 input_ids=input_ids,
@@ -1578,10 +1576,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
                 ),
             )
 
-            self.maybe_wait_for_kv_save()
-            finished_sending, finished_recving = (
-                self.get_finished_kv_transfers(scheduler_output))
-
         if self.use_aux_hidden_state_outputs:
             hidden_states, aux_hidden_states = model_output
         else:
@@ -1597,20 +1591,17 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
             == "external_launcher" and len(get_pp_group().ranks) > 0
         if not get_pp_group().is_last_rank:
             # For mid-pipeline stages, return the hidden states.
-            if not broadcast_pp_output:
-                if finished_sending or finished_recving:
-                    hidden_states.finished_sending = finished_sending
-                    hidden_states.finished_recving = finished_recving
-                return hidden_states
-            assert isinstance(hidden_states, IntermediateTensors)
+            if not broadcast_pp_output:
+                hidden_states.kv_connector_output = kv_connector_output
+                return hidden_states
+            assert isinstance(hidden_states, IntermediateTensors)
             get_pp_group().send_tensor_dict(hidden_states.tensors,
                                             all_gather_group=get_tp_group())
             logits = None
         else:
             if self.input_batch.pooling_params:
                 return self._pool(hidden_states, num_scheduled_tokens,
-                                  num_scheduled_tokens_np, finished_sending,
-                                  finished_recving)
+                                  num_scheduled_tokens_np, kv_connector_output)
 
         sample_hidden_states = hidden_states[logits_indices]
         logits = self.model.compute_logits(sample_hidden_states, None)
@@ -1760,8 +1751,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
             logprobs=logprobs_lists,
             prompt_logprobs_dict=prompt_logprobs_dict,
             pooler_output=[],
-            finished_sending=finished_sending,
-            finished_recving=finished_recving,
+            kv_connector_output=kv_connector_output,
             num_nans_in_logits=num_nans_in_logits,
         )
@@ -16,8 +16,7 @@ from vllm.config import VllmConfig
 from vllm.distributed import (ensure_model_parallel_initialized,
                               init_distributed_environment,
                               set_custom_all_reduce)
-from vllm.distributed.kv_transfer import (ensure_kv_transfer_initialized,
-                                          has_kv_transfer_group)
+from vllm.distributed.kv_transfer import ensure_kv_transfer_initialized
 from vllm.distributed.parallel_state import get_pp_group, get_tp_group
 from vllm.logger import init_logger
 from vllm.lora.request import LoRARequest
@@ -369,17 +368,20 @@ class Worker(WorkerBase):
             assert isinstance(output, IntermediateTensors)
             get_pp_group().send_tensor_dict(output.tensors,
                                             all_gather_group=get_tp_group())
-            if not has_kv_transfer_group():
+
+            kv_connector_output = output.kv_connector_output
+            if not kv_connector_output:
                 return None
 
             # In case of PP with kv transfer, we need to pass through the
-            # finished_sending and finished_recving buffers.
-            new_output = EMPTY_MODEL_RUNNER_OUTPUT
-            if output.finished_sending or output.finished_recving:
-                new_output = copy.copy(new_output)
-                new_output.finished_sending = output.finished_sending
-                new_output.finished_recving = output.finished_recving
-            output = new_output
+            # kv_connector_output
+            if (not kv_connector_output.finished_sending
+                    and not kv_connector_output.finished_recving):
+                return EMPTY_MODEL_RUNNER_OUTPUT
+
+            output = copy.copy(EMPTY_MODEL_RUNNER_OUTPUT)
+            output.kv_connector_output = kv_connector_output
+            return output
 
         assert isinstance(output, ModelRunnerOutput)
         return output
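Under pipeline parallelism only the last rank produces a real ModelRunnerOutput; earlier ranks now forward just the connector state. A condensed sketch of the decision above (passthrough is an illustrative free function, not the Worker method):

    # Illustrative condensation of the PP pass-through logic above.
    import copy

    from vllm.v1.outputs import EMPTY_MODEL_RUNNER_OUTPUT, KVConnectorOutput

    def passthrough(kv_connector_output):
        if not kv_connector_output:
            return None  # no connector state attached to this step
        if (not kv_connector_output.finished_sending
                and not kv_connector_output.finished_recving):
            return EMPTY_MODEL_RUNNER_OUTPUT
        # Shallow-copy so the shared EMPTY_MODEL_RUNNER_OUTPUT singleton is
        # never mutated.
        output = copy.copy(EMPTY_MODEL_RUNNER_OUTPUT)
        output.kv_connector_output = kv_connector_output
        return output

    out = passthrough(KVConnectorOutput(finished_sending={"req-9"}))
    assert out is not EMPTY_MODEL_RUNNER_OUTPUT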
@@ -4,6 +4,8 @@
 Define KV connector functionality mixin for model runners.
 """
 import copy
+from contextlib import AbstractContextManager, contextmanager, nullcontext
+from typing import Generator  # noqa: UP035
 from typing import TYPE_CHECKING, Optional
 
 from vllm.config import VllmConfig
@@ -12,7 +14,8 @@ from vllm.distributed.kv_transfer import (get_kv_transfer_group,
 from vllm.distributed.kv_transfer.kv_connector.v1 import KVConnectorBase_V1
 from vllm.forward_context import get_forward_context, set_forward_context
 from vllm.logger import init_logger
-from vllm.v1.outputs import EMPTY_MODEL_RUNNER_OUTPUT, ModelRunnerOutput
+from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, KVConnectorOutput,
+                             ModelRunnerOutput)
 
 if TYPE_CHECKING:
     from vllm.v1.core.sched.output import SchedulerOutput
@@ -53,18 +56,60 @@ class KVConnectorModelRunnerMixin:
             scheduler_output.finished_req_ids)
         return None, None
 
-    def kv_connector_no_forward(self, scheduler_output: "SchedulerOutput",
+    @staticmethod
+    def kv_connector_no_forward(scheduler_output: "SchedulerOutput",
                                 vllm_config: VllmConfig) -> ModelRunnerOutput:
         # KV send/recv even if no work to do.
-        with set_forward_context(None, vllm_config):
-            self.maybe_setup_kv_connector(scheduler_output)
-            finished_sending, finished_recving = (
-                self.get_finished_kv_transfers(scheduler_output))
+        with set_forward_context(
+                None, vllm_config
+        ), KVConnectorModelRunnerMixin._get_kv_connector_output(
+                scheduler_output, wait_for_save=False) as kv_connector_output:
+            pass
 
-        if not finished_sending and not finished_recving:
+        if (not kv_connector_output.finished_sending
+                and not kv_connector_output.finished_recving):
             return EMPTY_MODEL_RUNNER_OUTPUT
 
         output = copy.copy(EMPTY_MODEL_RUNNER_OUTPUT)
-        output.finished_sending = finished_sending
-        output.finished_recving = finished_recving
+        output.kv_connector_output = kv_connector_output
         return output
 
+    @staticmethod
+    def maybe_get_kv_connector_output(
+        scheduler_output: "SchedulerOutput"
+    ) -> AbstractContextManager[Optional[KVConnectorOutput]]:
+        return KVConnectorModelRunnerMixin._get_kv_connector_output(
+            scheduler_output) if has_kv_transfer_group() else nullcontext()
+
+    # This context manager must be used within an active forward context.
+    # It encapsulates the entire KV connector lifecycle within execute_model.
+    @staticmethod
+    @contextmanager
+    def _get_kv_connector_output(
+        scheduler_output: "SchedulerOutput",
+        wait_for_save: bool = True
+    ) -> Generator[KVConnectorOutput, None, None]:
+        output = KVConnectorOutput()
+
+        # Update KVConnector with the KVConnector metadata forward().
+        kv_connector = get_kv_transfer_group()
+        assert isinstance(kv_connector, KVConnectorBase_V1)
+        assert scheduler_output.kv_connector_metadata is not None
+        kv_connector.bind_connector_metadata(
+            scheduler_output.kv_connector_metadata)
+
+        # Background KV cache transfers happen here.
+        # These transfers are designed to be async and the requests
+        # involved may be disjoint from the running requests.
+        # Do this here to save a collective_rpc.
+        kv_connector.start_load_kv(get_forward_context())
+
+        try:
+            yield output
+        finally:
+            if wait_for_save:
+                kv_connector.wait_for_save()
+
+            output.finished_sending, output.finished_recving = (
+                kv_connector.get_finished(scheduler_output.finished_req_ids))
+
+            kv_connector.clear_connector_metadata()
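_get_kv_connector_output centralizes the per-step lifecycle: bind metadata, kick off async loads, yield while the forward pass runs, then optionally wait for saves, collect the finished request ids, and clear the metadata. The key design choice is yielding a mutable KVConnectorOutput that is filled in on exit, which is what lets execute_model wrap its forward pass in `with ... as kv_connector_output:` and still read the finished sets afterwards. A self-contained toy of that pattern (everything here is illustrative, not vLLM's classes):

    # Toy version of the yield-then-fill context-manager pattern.
    from contextlib import contextmanager
    from dataclasses import dataclass
    from typing import Iterator, Optional

    @dataclass
    class ToyKVConnectorOutput:
        finished_sending: Optional[set[str]] = None
        finished_recving: Optional[set[str]] = None

    @contextmanager
    def get_kv_connector_output() -> Iterator[ToyKVConnectorOutput]:
        output = ToyKVConnectorOutput()
        try:
            yield output  # the model forward runs inside the with-block
        finally:
            # Populated only after the forward pass has completed.
            output.finished_sending = {"req-1"}

    with get_kv_connector_output() as kv_out:
        pass  # stand-in for the forward pass
    assert kv_out.finished_sending == {"req-1"}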
@@ -51,7 +51,7 @@ from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, LogprobsLists,
 from vllm.v1.sample.tpu.metadata import TPUSupportedSamplingMetadata
 from vllm.v1.sample.tpu.sampler import Sampler as TPUSampler
 from vllm.v1.worker.kv_connector_model_runner_mixin import (
-    KVConnectorModelRunnerMixin)
+    KVConnectorModelRunnerMixin, KVConnectorOutput)
 from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin
 from vllm.v1.worker.tpu_input_batch import CachedRequestState, InputBatch
 
@@ -1175,9 +1175,10 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
             logprobs=logprobs_lists,
             prompt_logprobs_dict=prompt_logprobs_dict,
             pooler_output=[],
-            finished_sending=finished_sending,
-            finished_recving=finished_recving,
-        )
+            kv_connector_output=KVConnectorOutput(
+                finished_sending=finished_sending,
+                finished_recving=finished_recving,
+            ))
 
         # Check there are no new graphs compiled - all the graphs should be
         # captured and compiled during warm up.