Mirror of https://git.datalinker.icu/vllm-project/vllm.git (synced 2025-12-14 05:25:20 +08:00)

[V1] [P/D] Refactor KV Connector Path (#21980)

Signed-off-by: David Ben-David <davidb@pliops.com>
Co-authored-by: David Ben-David <davidb@pliops.com>

parent 24d1dffbeb
commit aefeea0fde
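The refactor collapses the loose finished_sending / finished_recving fields that were threaded through ModelRunnerOutput, IntermediateTensors, the scheduler, and the workers into a single KVConnectorOutput object, and wraps the connector's per-step lifecycle in a context manager. A minimal sketch of the new shape, using only names that appear in the diff below:

from dataclasses import dataclass
from typing import Optional


@dataclass
class KVConnectorOutput:
    # [req_ids] whose KV transfers completed in the last step
    finished_sending: Optional[set[str]] = None
    finished_recving: Optional[set[str]] = None


# Before: producers and consumers touched two parallel fields, e.g.
#     output.finished_sending, output.finished_recving
# After: one optional field that stays None on the non-P/D path:
#     output.kv_connector_output  # Optional[KVConnectorOutput]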
--- a/tests/v1/kv_connector/unit/test_output_aggregator.py
+++ b/tests/v1/kv_connector/unit/test_output_aggregator.py
@@ -4,7 +4,7 @@ from concurrent.futures import Future
 from typing import Optional
 
 from vllm.distributed.kv_transfer.kv_connector.utils import KVOutputAggregator
-from vllm.v1.outputs import ModelRunnerOutput
+from vllm.v1.outputs import KVConnectorOutput, ModelRunnerOutput
 
 
 class DummyModelRunnerOutput(ModelRunnerOutput):
@@ -12,8 +12,16 @@ class DummyModelRunnerOutput(ModelRunnerOutput):
     def __init__(self,
                  finished_sending: Optional[set[str]] = None,
                  finished_recving: Optional[set[str]] = None):
-        self.finished_sending = finished_sending
-        self.finished_recving = finished_recving
+        self.kv_connector_output = KVConnectorOutput(
+            finished_sending=finished_sending,
+            finished_recving=finished_recving,
+        )
+
+    def __repr__(self):
+        return (
+            f"DummyModelRunnerOutput("
+            f"finished_sending={self.kv_connector_output.finished_sending},"
+            f"finished_recving={self.kv_connector_output.finished_recving})")
 
 
 def test_aggregate_workers_output():
@@ -27,6 +35,7 @@ def test_aggregate_workers_output():
     aggregated = aggregator.aggregate([output1, output2])
 
     assert aggregated is output1
+    aggregated = aggregated.kv_connector_output
     assert aggregated.finished_sending is None
     assert aggregated.finished_recving is None
 
@@ -38,6 +47,7 @@ def test_aggregate_workers_output():
     aggregated = aggregator.aggregate([output1, output2])
 
     assert aggregated is output1
+    aggregated = aggregated.kv_connector_output
     assert aggregated.finished_sending == {'req1'}
     assert aggregated.finished_recving is None
 
@@ -49,6 +59,7 @@ def test_aggregate_workers_output():
     aggregated = aggregator.aggregate([output1, output2])
 
     assert aggregated is output1
+    aggregated = aggregated.kv_connector_output
     assert aggregated.finished_sending is None
     assert aggregated.finished_recving == {'req2'}
 
@@ -70,6 +81,7 @@ def test_async_aggregate_workers_output():
     assert result_future.done()
     aggregated = result_future.result()
     assert aggregated is output1
+    aggregated = aggregated.kv_connector_output
     assert aggregated.finished_sending is None
     assert aggregated.finished_recving is None
 
@@ -87,6 +99,7 @@ def test_async_aggregate_workers_output():
     assert result_future.done()
     aggregated = result_future.result()
     assert aggregated is output1
+    aggregated = aggregated.kv_connector_output
     assert aggregated.finished_sending == {'req1'}
     assert aggregated.finished_recving is None
 
@@ -104,5 +117,6 @@ def test_async_aggregate_workers_output():
     assert result_future.done()
     aggregated = result_future.result()
     assert aggregated is output1
+    aggregated = aggregated.kv_connector_output
     assert aggregated.finished_sending is None
     assert aggregated.finished_recving == {'req2'}

--- a/tests/v1/kv_connector/unit/test_remote_decode_lifecycle.py
+++ b/tests/v1/kv_connector/unit/test_remote_decode_lifecycle.py
@@ -2,7 +2,7 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 import copy
 
-from vllm.v1.outputs import EMPTY_MODEL_RUNNER_OUTPUT
+from vllm.v1.outputs import EMPTY_MODEL_RUNNER_OUTPUT, KVConnectorOutput
 from vllm.v1.request import FinishReason, RequestStatus
 
 from .utils import (assert_scheduler_empty, create_model_runner_output,
@@ -86,7 +86,8 @@ def test_basic_lifecycle():
 
     # (3b): execute_model()
     model_runner_output = copy.deepcopy(EMPTY_MODEL_RUNNER_OUTPUT)
-    model_runner_output.finished_sending = [request_id]
+    model_runner_output.kv_connector_output = KVConnectorOutput(
+        finished_sending=[request_id])
 
     # (3c): update_from_output()
     scheduler.update_from_output(scheduler_output, model_runner_output)
@@ -176,7 +177,8 @@ def test_prefix_cache_lifecycle():
     scheduler_output = scheduler.schedule()
     scheduler.schedule()
     model_runner_output = copy.deepcopy(EMPTY_MODEL_RUNNER_OUTPUT)
-    model_runner_output.finished_sending = [request_remote.request_id]
+    model_runner_output.kv_connector_output = KVConnectorOutput(
+        finished_sending=[request_remote.request_id])
     scheduler.update_from_output(scheduler_output, model_runner_output)
     _ = scheduler.schedule()
     assert_scheduler_empty(scheduler)

--- a/tests/v1/kv_connector/unit/test_remote_prefill_lifecycle.py
+++ b/tests/v1/kv_connector/unit/test_remote_prefill_lifecycle.py
@@ -2,7 +2,7 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 import copy
 
-from vllm.v1.outputs import EMPTY_MODEL_RUNNER_OUTPUT
+from vllm.v1.outputs import EMPTY_MODEL_RUNNER_OUTPUT, KVConnectorOutput
 from vllm.v1.request import FinishReason, RequestStatus
 
 from .utils import (assert_scheduler_empty, create_model_runner_output,
@@ -72,7 +72,8 @@ def test_basic_lifecycle():
 
     # (2b): forward(): request finishes recv.
     model_runner_output = copy.deepcopy(EMPTY_MODEL_RUNNER_OUTPUT)
-    model_runner_output.finished_recving = [request_id]
+    model_runner_output.kv_connector_output = KVConnectorOutput(
+        finished_recving=[request_id])
 
     # (2c): update_from_output():
     engine_core_outputs = scheduler.update_from_output(scheduler_output,
@@ -309,7 +310,8 @@ def test_full_block_prompt():
     # # STEP (2): Recv.
     scheduler_output = scheduler.schedule()
     model_runner_output = copy.deepcopy(EMPTY_MODEL_RUNNER_OUTPUT)
-    model_runner_output.finished_recving = [request_id]
+    model_runner_output.kv_connector_output = KVConnectorOutput(
+        finished_recving=[request_id])
     scheduler.update_from_output(scheduler_output, model_runner_output)
     assert len(scheduler.waiting) == 1
     assert (request_id in scheduler.finished_recving_kv_req_ids)

--- a/tests/v1/kv_connector/unit/utils.py
+++ b/tests/v1/kv_connector/unit/utils.py
@@ -17,7 +17,7 @@ from vllm.v1.core.kv_cache_manager import KVCacheBlocks
 from vllm.v1.core.sched.scheduler import Scheduler
 from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
                                         KVCacheGroupSpec)
-from vllm.v1.outputs import ModelRunnerOutput
+from vllm.v1.outputs import KVConnectorOutput, ModelRunnerOutput
 from vllm.v1.request import Request
 from vllm.v1.structured_output import StructuredOutputManager
 
@@ -188,8 +188,10 @@ def create_model_runner_output(
         logprobs=None,
         prompt_logprobs_dict={},
         pooler_output=None,
-        finished_sending=finished_sending,
-        finished_recving=finished_recving,
+        kv_connector_output=KVConnectorOutput(
+            finished_sending=finished_sending,
+            finished_recving=finished_recving,
+        ),
     )
 
 
--- a/vllm/distributed/kv_transfer/kv_connector/utils.py
+++ b/vllm/distributed/kv_transfer/kv_connector/utils.py
@@ -16,7 +16,7 @@ from vllm.config import VllmConfig, get_current_vllm_config
 from vllm.distributed.kv_transfer.kv_connector.factory import (
     KVConnectorFactory)
 from vllm.logger import init_logger
-from vllm.v1.outputs import ModelRunnerOutput
+from vllm.v1.outputs import KVConnectorOutput, ModelRunnerOutput
 
 logger = init_logger(__name__)
 
@@ -129,7 +129,7 @@ class KVOutputAggregator:
     def aggregate(self,
                   outputs: list[ModelRunnerOutput],
                   output_rank: int = 0) -> ModelRunnerOutput:
-        # aggregate finished_sending, finished_recving from all workers
+        # aggregate kv_connector_output from all workers
 
         def update_finished_set(req_ids: Optional[set[str]],
                                 remaining_count_dict: dict[str, int],
@@ -143,6 +143,7 @@ class KVOutputAggregator:
         finished_sending = set[str]()
         finished_recving = set[str]()
         for output in outputs:
+            output = output.kv_connector_output
             update_finished_set(output.finished_sending,
                                 self._send_remaining_count, finished_sending)
             update_finished_set(output.finished_recving,
@@ -151,13 +152,10 @@
         # select output of the worker specified by output_rank
         output = outputs[output_rank]
 
-        # set the aggregated finished_sending / finished_recving
-        # if output.finished_sending/recving is not empty, but the other ranks
-        # still have unfinished send/recv, we want to set the aggregated
-        # finished_sending/recving to None until all ranks have finished
-        # send/recv
-        output.finished_sending = finished_sending if finished_sending else None
-        output.finished_recving = finished_recving if finished_recving else None
+        output.kv_connector_output = KVConnectorOutput(
+            finished_sending=finished_sending or None,
+            finished_recving=finished_recving or None,
+        )
 
         return output
 
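KVOutputAggregator.aggregate merges per-worker results so that a request is surfaced as finished only once every rank has reported it; a request reported by some ranks but not others stays pending across steps. The body of update_finished_set is not part of this diff; the standalone sketch below reproduces the ref-counting idea under that assumption, with a hypothetical world_size parameter:

from typing import Optional


def aggregate_finished(per_worker: list[Optional[set[str]]],
                       remaining: dict[str, int],
                       world_size: int) -> Optional[set[str]]:
    """Return req_ids reported finished by all workers, else None.

    `remaining` persists across calls: a request stays pending until
    every one of the `world_size` workers has reported it once.
    """
    finished: set[str] = set()
    for req_ids in per_worker:
        for req_id in req_ids or ():
            count = remaining.get(req_id, world_size) - 1
            if count == 0:
                remaining.pop(req_id, None)
                finished.add(req_id)
            else:
                remaining[req_id] = count
    return finished or None


# With TP=2, 'req1' surfaces only after both workers have reported it:
remaining: dict[str, int] = {}
print(aggregate_finished([{"req1"}, None], remaining, world_size=2))  # None
print(aggregate_finished([None, {"req1"}], remaining, world_size=2))  # {'req1'}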
--- a/vllm/sequence.py
+++ b/vllm/sequence.py
@@ -10,7 +10,7 @@ from collections.abc import Mapping
 from collections.abc import Sequence as GenericSequence
 from dataclasses import dataclass, field
 from functools import reduce
-from typing import Any, Callable, Optional, Union
+from typing import TYPE_CHECKING, Any, Callable, Optional, Union
 
 import msgspec
 import torch
@@ -21,6 +21,10 @@ from vllm.multimodal import MultiModalKwargs, MultiModalPlaceholderDict
 from vllm.pooling_params import PoolingParams
 from vllm.sampling_params import RequestOutputKind, SamplingParams
 
+if TYPE_CHECKING:
+    from vllm.v1.worker.kv_connector_model_runner_mixin import (
+        KVConnectorOutput)
+
 VLLM_TOKEN_ID_ARRAY_TYPE = "l"
 
 VLLM_INVALID_TOKEN_ID = -1
@@ -1159,14 +1163,11 @@ class IntermediateTensors:
     states and residuals to be sent to the next stage. This data structure
     contains the hidden states and residuals for a request.
 
-    Each stage also needs to handle its own finished_sending and
-    finished_recving in case of kv transfer.
+    Each stage also needs to handle its own kv_connector_output.
     """
 
     tensors: dict[str, torch.Tensor]
-    # [req_ids]
-    finished_sending: Optional[set[str]] = None
-    finished_recving: Optional[set[str]] = None
+    kv_connector_output: Optional["KVConnectorOutput"]
 
     def __init__(self, tensors):
         # manually define this function, so that
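The TYPE_CHECKING guard plus the quoted annotation lets IntermediateTensors reference KVConnectorOutput without importing the worker mixin at runtime, which would create an import cycle. A generic sketch of the pattern, with hypothetical module names a/b:

# a.py -- use a type from b.py in annotations without a runtime import
from typing import TYPE_CHECKING, Optional

if TYPE_CHECKING:
    # Seen by type checkers only; TYPE_CHECKING is False at runtime,
    # so no import (and no cycle) actually happens.
    from b import Heavy


class Carrier:
    # The quoted "forward reference" defers evaluating the name.
    payload: Optional["Heavy"] = None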
--- a/vllm/v1/core/sched/scheduler.py
+++ b/vllm/v1/core/sched/scheduler.py
@@ -30,7 +30,7 @@ from vllm.v1.engine import (EngineCoreEventType, EngineCoreOutput,
                             EngineCoreOutputs)
 from vllm.v1.kv_cache_interface import KVCacheConfig
 from vllm.v1.metrics.stats import SchedulerStats
-from vllm.v1.outputs import ModelRunnerOutput
+from vllm.v1.outputs import KVConnectorOutput, ModelRunnerOutput
 from vllm.v1.request import Request, RequestStatus
 from vllm.v1.spec_decode.metrics import SpecDecodingStats
 from vllm.v1.structured_output import StructuredOutputManager
@@ -884,7 +884,9 @@ class Scheduler(SchedulerInterface):
         self.waiting.remove_requests(stopped_preempted_reqs)
 
         # KV Connector: update state for finished KV Transfers.
-        self._update_from_kv_xfer_finished(model_runner_output)
+        if model_runner_output.kv_connector_output:
+            self._update_from_kv_xfer_finished(
+                model_runner_output.kv_connector_output)
 
         # Create EngineCoreOutputs for all clients that have requests with
         # outputs in this step.
@@ -1128,7 +1130,7 @@
         return True
 
     def _update_from_kv_xfer_finished(self,
-                                      model_runner_output: ModelRunnerOutput):
+                                      kv_connector_output: KVConnectorOutput):
         """
         KV Connector: update the scheduler state based on the output.
 
@@ -1139,9 +1141,9 @@
         scheduler the request during the next step.
         """
         # KV Connector:: update recv and send status from last step.
-        for req_id in (model_runner_output.finished_recving or ()):
+        for req_id in (kv_connector_output.finished_recving or ()):
             logger.debug("Finished recving KV transfer for request %s", req_id)
             self.finished_recving_kv_req_ids.add(req_id)
-        for req_id in (model_runner_output.finished_sending or ()):
+        for req_id in (kv_connector_output.finished_sending or ()):
             logger.debug("Finished sending KV transfer for request %s", req_id)
             self._free_blocks(self.requests[req_id])

--- a/vllm/v1/outputs.py
+++ b/vllm/v1/outputs.py
@@ -71,6 +71,13 @@ class SamplerOutput:
     logprobs_tensors: Optional[LogprobsTensors]
 
 
+@dataclass
+class KVConnectorOutput:
+    # [req_ids]
+    finished_sending: Optional[set[str]] = None
+    finished_recving: Optional[set[str]] = None
+
+
 # ModelRunnerOutput is serialized and sent to the scheduler process.
 # This is expensive for torch.Tensor so prefer to use list instead.
 @dataclass
@@ -104,9 +111,7 @@ class ModelRunnerOutput:
     # [num_reqs, hidden_size]
     pooler_output: list[Optional[torch.Tensor]]
 
-    # [req_ids]
-    finished_sending: Optional[set[str]] = None
-    finished_recving: Optional[set[str]] = None
+    kv_connector_output: Optional[KVConnectorOutput] = None
 
     # req_id -> num_nans_in_logits
     num_nans_in_logits: Optional[dict[str, int]] = None
@@ -119,6 +124,4 @@ EMPTY_MODEL_RUNNER_OUTPUT = ModelRunnerOutput(req_ids=[],
                                               logprobs=None,
                                               prompt_logprobs_dict={},
                                               pooler_output=[],
-                                              finished_sending=None,
-                                              finished_recving=None,
                                               num_nans_in_logits=None)

--- a/vllm/v1/worker/gpu_model_runner.py
+++ b/vllm/v1/worker/gpu_model_runner.py
@@ -69,7 +69,7 @@ from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
 from vllm.v1.spec_decode.ngram_proposer import NgramProposer
 from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
 from vllm.v1.worker.kv_connector_model_runner_mixin import (
-    KVConnectorModelRunnerMixin)
+    KVConnectorModelRunnerMixin, KVConnectorOutput)
 from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin
 
 from ..sample.logits_processor import LogitsProcessorManager
@@ -1423,8 +1423,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         hidden_states: torch.Tensor,
         num_scheduled_tokens: int,
         num_scheduled_tokens_np: np.ndarray,
-        finished_sending: Optional[set[str]],
-        finished_recving: Optional[set[str]],
+        kv_connector_output: Optional[KVConnectorOutput],
     ) -> ModelRunnerOutput:
         assert self.input_batch.num_reqs ==\
             len(self.input_batch.pooling_params), \
@@ -1459,8 +1458,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
             logprobs=None,
             prompt_logprobs_dict={},
             pooler_output=pooler_output,
-            finished_sending=finished_sending,
-            finished_recving=finished_recving,
+            kv_connector_output=kv_connector_output,
         )
 
     @torch.inference_mode()
@@ -1564,8 +1562,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
                 num_tokens=num_input_tokens,
                 num_tokens_across_dp=num_tokens_across_dp,
                 skip_cuda_graphs=skip_cuda_graphs,
-        ):
-            self.maybe_setup_kv_connector(scheduler_output)
+        ), self.maybe_get_kv_connector_output(
+                scheduler_output) as kv_connector_output:
 
             model_output = self.model(
                 input_ids=input_ids,
@@ -1578,10 +1576,6 @@
                 ),
             )
 
-            self.maybe_wait_for_kv_save()
-            finished_sending, finished_recving = (
-                self.get_finished_kv_transfers(scheduler_output))
-
         if self.use_aux_hidden_state_outputs:
             hidden_states, aux_hidden_states = model_output
         else:
@@ -1597,20 +1591,17 @@
                 == "external_launcher" and len(get_pp_group().ranks) > 0
         if not get_pp_group().is_last_rank:
             # For mid-pipeline stages, return the hidden states.
-            if not broadcast_pp_output:
-                if finished_sending or finished_recving:
-                    hidden_states.finished_sending = finished_sending
-                    hidden_states.finished_recving = finished_recving
-                return hidden_states
             assert isinstance(hidden_states, IntermediateTensors)
+            if not broadcast_pp_output:
+                hidden_states.kv_connector_output = kv_connector_output
+                return hidden_states
             get_pp_group().send_tensor_dict(hidden_states.tensors,
                                             all_gather_group=get_tp_group())
             logits = None
         else:
             if self.input_batch.pooling_params:
                 return self._pool(hidden_states, num_scheduled_tokens,
-                                  num_scheduled_tokens_np, finished_sending,
-                                  finished_recving)
+                                  num_scheduled_tokens_np, kv_connector_output)
 
         sample_hidden_states = hidden_states[logits_indices]
         logits = self.model.compute_logits(sample_hidden_states, None)
@@ -1760,8 +1751,7 @@
             logprobs=logprobs_lists,
             prompt_logprobs_dict=prompt_logprobs_dict,
             pooler_output=[],
-            finished_sending=finished_sending,
-            finished_recving=finished_recving,
+            kv_connector_output=kv_connector_output,
             num_nans_in_logits=num_nans_in_logits,
         )

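maybe_get_kv_connector_output joins the existing set_forward_context(...) manager in a single with statement; when no KV-transfer group is configured it returns contextlib.nullcontext(), whose `as` target is None, so the rest of execute_model can pass kv_connector_output through unconditionally. A small self-contained sketch of that dispatch (the `enabled` flag is a hypothetical stand-in for has_kv_transfer_group()):

from contextlib import contextmanager, nullcontext

enabled = False  # hypothetical stand-in for has_kv_transfer_group()


@contextmanager
def real_output():
    yield "kv-connector-output"  # placeholder payload


def maybe_get_output():
    # nullcontext() yields None, so the `as` target is None when disabled.
    return real_output() if enabled else nullcontext()


with maybe_get_output() as kv_connector_output:
    print(kv_connector_output)  # None here; the payload when enabled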
--- a/vllm/v1/worker/gpu_worker.py
+++ b/vllm/v1/worker/gpu_worker.py
@@ -16,8 +16,7 @@ from vllm.config import VllmConfig
 from vllm.distributed import (ensure_model_parallel_initialized,
                               init_distributed_environment,
                               set_custom_all_reduce)
-from vllm.distributed.kv_transfer import (ensure_kv_transfer_initialized,
-                                          has_kv_transfer_group)
+from vllm.distributed.kv_transfer import ensure_kv_transfer_initialized
 from vllm.distributed.parallel_state import get_pp_group, get_tp_group
 from vllm.logger import init_logger
 from vllm.lora.request import LoRARequest
@@ -369,17 +368,20 @@ class Worker(WorkerBase):
             assert isinstance(output, IntermediateTensors)
             get_pp_group().send_tensor_dict(output.tensors,
                                             all_gather_group=get_tp_group())
-            if not has_kv_transfer_group():
+
+            kv_connector_output = output.kv_connector_output
+            if not kv_connector_output:
                 return None
 
             # In case of PP with kv transfer, we need to pass through the
-            # finished_sending and finished_recving buffers.
-            new_output = EMPTY_MODEL_RUNNER_OUTPUT
-            if output.finished_sending or output.finished_recving:
-                new_output = copy.copy(new_output)
-                new_output.finished_sending = output.finished_sending
-                new_output.finished_recving = output.finished_recving
-            output = new_output
+            # kv_connector_output
+            if (not kv_connector_output.finished_sending
+                    and not kv_connector_output.finished_recving):
+                return EMPTY_MODEL_RUNNER_OUTPUT
+
+            output = copy.copy(EMPTY_MODEL_RUNNER_OUTPUT)
+            output.kv_connector_output = kv_connector_output
+            return output
 
         assert isinstance(output, ModelRunnerOutput)
         return output

--- a/vllm/v1/worker/kv_connector_model_runner_mixin.py
+++ b/vllm/v1/worker/kv_connector_model_runner_mixin.py
@@ -4,6 +4,8 @@
 Define KV connector functionality mixin for model runners.
 """
 import copy
+from contextlib import AbstractContextManager, contextmanager, nullcontext
+from typing import Generator  # noqa: UP035
 from typing import TYPE_CHECKING, Optional
 
 from vllm.config import VllmConfig
@@ -12,7 +14,8 @@ from vllm.distributed.kv_transfer import (get_kv_transfer_group,
 from vllm.distributed.kv_transfer.kv_connector.v1 import KVConnectorBase_V1
 from vllm.forward_context import get_forward_context, set_forward_context
 from vllm.logger import init_logger
-from vllm.v1.outputs import EMPTY_MODEL_RUNNER_OUTPUT, ModelRunnerOutput
+from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, KVConnectorOutput,
+                             ModelRunnerOutput)
 
 if TYPE_CHECKING:
     from vllm.v1.core.sched.output import SchedulerOutput
@@ -53,18 +56,60 @@ class KVConnectorModelRunnerMixin:
                     scheduler_output.finished_req_ids)
         return None, None
 
-    def kv_connector_no_forward(self, scheduler_output: "SchedulerOutput",
+    @staticmethod
+    def kv_connector_no_forward(scheduler_output: "SchedulerOutput",
                                 vllm_config: VllmConfig) -> ModelRunnerOutput:
         # KV send/recv even if no work to do.
-        with set_forward_context(None, vllm_config):
-            self.maybe_setup_kv_connector(scheduler_output)
-            finished_sending, finished_recving = (
-                self.get_finished_kv_transfers(scheduler_output))
+        with set_forward_context(
+                None, vllm_config
+        ), KVConnectorModelRunnerMixin._get_kv_connector_output(
+                scheduler_output, wait_for_save=False) as kv_connector_output:
+            pass
 
-        if not finished_sending and not finished_recving:
+        if (not kv_connector_output.finished_sending
+                and not kv_connector_output.finished_recving):
             return EMPTY_MODEL_RUNNER_OUTPUT
 
         output = copy.copy(EMPTY_MODEL_RUNNER_OUTPUT)
-        output.finished_sending = finished_sending
-        output.finished_recving = finished_recving
+        output.kv_connector_output = kv_connector_output
         return output
+
+    @staticmethod
+    def maybe_get_kv_connector_output(
+        scheduler_output: "SchedulerOutput"
+    ) -> AbstractContextManager[Optional[KVConnectorOutput]]:
+        return KVConnectorModelRunnerMixin._get_kv_connector_output(
+            scheduler_output) if has_kv_transfer_group() else nullcontext()
+
+    # This context manager must be used within an active forward context.
+    # It encapsulates the entire KV connector lifecycle within execute_model
+    @staticmethod
+    @contextmanager
+    def _get_kv_connector_output(
+        scheduler_output: "SchedulerOutput",
+        wait_for_save: bool = True
+    ) -> Generator[KVConnectorOutput, None, None]:
+        output = KVConnectorOutput()
+
+        # Update KVConnector with the KVConnector metadata forward().
+        kv_connector = get_kv_transfer_group()
+        assert isinstance(kv_connector, KVConnectorBase_V1)
+        assert scheduler_output.kv_connector_metadata is not None
+        kv_connector.bind_connector_metadata(
+            scheduler_output.kv_connector_metadata)
+
+        # Background KV cache transfers happen here.
+        # These transfers are designed to be async and the requests
+        # involved may be disjoint from the running requests.
+        # Do this here to save a collective_rpc.
+        kv_connector.start_load_kv(get_forward_context())
+        try:
+            yield output
+        finally:
+            if wait_for_save:
+                kv_connector.wait_for_save()
+
+            output.finished_sending, output.finished_recving = (
+                kv_connector.get_finished(scheduler_output.finished_req_ids))
+
+            kv_connector.clear_connector_metadata()
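Taken together, _get_kv_connector_output brackets the forward pass with the connector lifecycle: bind metadata, kick off async KV loads, yield while the model runs, then (optionally) wait for saves, collect finished transfers, and clear metadata, even if the forward pass raises. A self-contained sketch of that ordering with a hypothetical stub connector (method names follow the diff; bodies are placeholders):

from contextlib import contextmanager
from dataclasses import dataclass
from typing import Optional


@dataclass
class KVConnectorOutput:
    finished_sending: Optional[set[str]] = None
    finished_recving: Optional[set[str]] = None


class StubConnector:
    """Hypothetical stand-in for a KVConnectorBase_V1 implementation."""

    def bind_connector_metadata(self, metadata):
        print("1. bind metadata")

    def start_load_kv(self, forward_context):
        print("2. start async KV loads")

    def wait_for_save(self):
        print("4. wait for KV saves")

    def get_finished(self, finished_req_ids):
        print("5. collect finished transfers")
        return {"req1"}, None  # (finished_sending, finished_recving)

    def clear_connector_metadata(self):
        print("6. clear metadata")


@contextmanager
def kv_connector_step(connector, metadata, finished_req_ids,
                      wait_for_save=True):
    # Mirrors _get_kv_connector_output: the forward pass runs in the
    # body of the `with`; the finally-block runs even on exceptions.
    output = KVConnectorOutput()
    connector.bind_connector_metadata(metadata)
    connector.start_load_kv(forward_context=None)
    try:
        yield output
    finally:
        if wait_for_save:
            connector.wait_for_save()
        output.finished_sending, output.finished_recving = (
            connector.get_finished(finished_req_ids))
        connector.clear_connector_metadata()


with kv_connector_step(StubConnector(), metadata=None,
                       finished_req_ids=set()) as out:
    print("3. model forward runs here")
print(out.finished_sending)  # {'req1'}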
--- a/vllm/v1/worker/tpu_model_runner.py
+++ b/vllm/v1/worker/tpu_model_runner.py
@@ -51,7 +51,7 @@ from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, LogprobsLists,
 from vllm.v1.sample.tpu.metadata import TPUSupportedSamplingMetadata
 from vllm.v1.sample.tpu.sampler import Sampler as TPUSampler
 from vllm.v1.worker.kv_connector_model_runner_mixin import (
-    KVConnectorModelRunnerMixin)
+    KVConnectorModelRunnerMixin, KVConnectorOutput)
 from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin
 from vllm.v1.worker.tpu_input_batch import CachedRequestState, InputBatch
 
@@ -1175,9 +1175,10 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
             logprobs=logprobs_lists,
             prompt_logprobs_dict=prompt_logprobs_dict,
             pooler_output=[],
-            finished_sending=finished_sending,
-            finished_recving=finished_recving,
-        )
+            kv_connector_output=KVConnectorOutput(
+                finished_sending=finished_sending,
+                finished_recving=finished_recving,
+            ))
 
         # Check there are no new graphs compiled - all the graphs should be
         # captured and compiled during warm up.