mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-29 14:40:54 +08:00
[BugFix] Fix PP performance and PP kv connector output regression (#28768)
Signed-off-by: Nick Hill <nhill@redhat.com>
This commit is contained in:
parent
d8874c61a5
commit
7765e5ba75
@ -63,7 +63,6 @@ from vllm.v1.outputs import ModelRunnerOutput
|
||||
from vllm.v1.request import Request, RequestStatus
|
||||
from vllm.v1.serial_utils import MsgpackDecoder, MsgpackEncoder
|
||||
from vllm.v1.structured_output import StructuredOutputManager
|
||||
from vllm.v1.utils import record_function_or_nullcontext
|
||||
from vllm.version import __version__ as VLLM_VERSION
|
||||
|
||||
logger = init_logger(__name__)
|
||||
@ -181,11 +180,13 @@ class EngineCore:
|
||||
logger.info("Batch queue is enabled with size %d", self.batch_queue_size)
|
||||
self.batch_queue = deque(maxlen=self.batch_queue_size)
|
||||
|
||||
self.ec_producer = (
|
||||
vllm_config.ec_transfer_config is not None
|
||||
and vllm_config.ec_transfer_config.is_ec_producer
|
||||
)
|
||||
|
||||
self.request_block_hasher: Callable[[Request], list[BlockHash]] | None = None
|
||||
if (
|
||||
self.vllm_config.cache_config.enable_prefix_caching
|
||||
or kv_connector is not None
|
||||
):
|
||||
if vllm_config.cache_config.enable_prefix_caching or kv_connector is not None:
|
||||
caching_hash_fn = get_hash_fn_by_name(
|
||||
vllm_config.cache_config.prefix_caching_hash_algo
|
||||
)
|
||||
@ -246,7 +247,7 @@ class EngineCore:
|
||||
|
||||
elapsed = time.time() - start
|
||||
logger.info_once(
|
||||
("init engine (profile, create kv cache, warmup model) took %.2f seconds"),
|
||||
"init engine (profile, create kv cache, warmup model) took %.2f seconds",
|
||||
elapsed,
|
||||
scope="local",
|
||||
)
|
||||
@ -312,6 +313,16 @@ class EngineCore:
|
||||
)
|
||||
raise err
|
||||
|
||||
def _log_err_callback(self, scheduler_output: SchedulerOutput):
|
||||
"""Log error details of a future that's not expected to return a result."""
|
||||
|
||||
def callback(f, sched_output=scheduler_output):
|
||||
with self.log_error_detail(sched_output):
|
||||
result = f.result()
|
||||
assert result is None
|
||||
|
||||
return callback
|
||||
|
||||
def step(self) -> tuple[dict[int, EngineCoreOutputs], bool]:
|
||||
"""Schedule, execute, and make output.
|
||||
|
||||
@ -323,21 +334,17 @@ class EngineCore:
|
||||
# or finished and not yet removed from the batch.
|
||||
if not self.scheduler.has_requests():
|
||||
return {}, False
|
||||
with record_function_or_nullcontext("core step: schedule"):
|
||||
scheduler_output = self.scheduler.schedule()
|
||||
scheduler_output = self.scheduler.schedule()
|
||||
future = self.model_executor.execute_model(scheduler_output, non_block=True)
|
||||
grammar_output = self.scheduler.get_grammar_bitmask(scheduler_output)
|
||||
with self.log_error_detail(scheduler_output):
|
||||
model_output = future.result()
|
||||
if model_output is None:
|
||||
model_output = self.model_executor.sample_tokens(grammar_output)
|
||||
|
||||
with record_function_or_nullcontext("core step: execute_model"):
|
||||
future = self.model_executor.execute_model(scheduler_output, non_block=True)
|
||||
grammar_output = self.scheduler.get_grammar_bitmask(scheduler_output)
|
||||
with self.log_error_detail(scheduler_output):
|
||||
model_output = future.result()
|
||||
if model_output is None:
|
||||
model_output = self.model_executor.sample_tokens(grammar_output)
|
||||
|
||||
with record_function_or_nullcontext("core step: update_from_output"):
|
||||
engine_core_outputs = self.scheduler.update_from_output(
|
||||
scheduler_output, model_output
|
||||
)
|
||||
engine_core_outputs = self.scheduler.update_from_output(
|
||||
scheduler_output, model_output
|
||||
)
|
||||
|
||||
return engine_core_outputs, scheduler_output.total_num_scheduled_tokens > 0
|
||||
|
||||
@ -378,52 +385,34 @@ class EngineCore:
|
||||
model_executed = False
|
||||
deferred_scheduler_output = None
|
||||
if self.scheduler.has_requests():
|
||||
with record_function_or_nullcontext("core step_with_batch_queue: schedule"):
|
||||
scheduler_output = self.scheduler.schedule()
|
||||
with record_function_or_nullcontext(
|
||||
"core step_with_batch_queue: execute_model"
|
||||
):
|
||||
exec_future = self.model_executor.execute_model(
|
||||
scheduler_output, non_block=True
|
||||
)
|
||||
model_executed = scheduler_output.total_num_scheduled_tokens > 0
|
||||
scheduler_output = self.scheduler.schedule()
|
||||
exec_future = self.model_executor.execute_model(
|
||||
scheduler_output, non_block=True
|
||||
)
|
||||
if not self.ec_producer:
|
||||
model_executed = scheduler_output.total_num_scheduled_tokens > 0
|
||||
|
||||
if scheduler_output.pending_structured_output_tokens:
|
||||
with record_function_or_nullcontext(
|
||||
"core step_with_batch_queue: pending_structured_output_tokens"
|
||||
):
|
||||
# We need to defer sampling until we have processed the model output
|
||||
# from the prior step.
|
||||
deferred_scheduler_output = scheduler_output
|
||||
# Block-wait for execute to return
|
||||
# (continues running async on the GPU).
|
||||
with self.log_error_detail(scheduler_output):
|
||||
exec_result = exec_future.result()
|
||||
assert exec_result is None
|
||||
if not model_executed:
|
||||
# No sampling required (no requests scheduled).
|
||||
future = cast(Future[ModelRunnerOutput], exec_future)
|
||||
else:
|
||||
with record_function_or_nullcontext(
|
||||
"core step_with_batch_queue: get_grammar_bitmask"
|
||||
):
|
||||
# We aren't waiting for any tokens, get any grammar
|
||||
# output immediately.
|
||||
exec_future.add_done_callback(self._log_err_callback(scheduler_output))
|
||||
|
||||
if not scheduler_output.pending_structured_output_tokens:
|
||||
# We aren't waiting for any tokens, get any grammar output
|
||||
# and sample immediately.
|
||||
grammar_output = self.scheduler.get_grammar_bitmask(
|
||||
scheduler_output
|
||||
)
|
||||
# Block-wait for execute to return (continues running async on the GPU).
|
||||
with self.log_error_detail(scheduler_output):
|
||||
exec_result = exec_future.result()
|
||||
|
||||
if exec_result is None:
|
||||
with record_function_or_nullcontext(
|
||||
"core step_with_batch_queue: sample_tokens"
|
||||
):
|
||||
# Call sample tokens.
|
||||
future = self.model_executor.sample_tokens(
|
||||
grammar_output, non_block=True
|
||||
)
|
||||
future = self.model_executor.sample_tokens(
|
||||
grammar_output, non_block=True
|
||||
)
|
||||
else:
|
||||
# No sampling required (e.g. all requests finished).
|
||||
future = cast(Future[ModelRunnerOutput], exec_future)
|
||||
# We need to defer sampling until we have processed the model output
|
||||
# from the prior step.
|
||||
deferred_scheduler_output = scheduler_output
|
||||
|
||||
if not deferred_scheduler_output:
|
||||
# Add this step's future to the queue.
|
||||
batch_queue.appendleft((future, scheduler_output))
|
||||
if (
|
||||
@ -440,34 +429,27 @@ class EngineCore:
|
||||
# only be called when the scheduler contains requests or the queue
|
||||
# is non-empty.
|
||||
return None, False
|
||||
with record_function_or_nullcontext("core step_with_batch_queue: model_output"):
|
||||
# Block until the next result is available.
|
||||
future, scheduler_output = batch_queue.pop()
|
||||
with self.log_error_detail(scheduler_output):
|
||||
model_output = future.result()
|
||||
with record_function_or_nullcontext(
|
||||
"core step_with_batch_queue: update_from_output"
|
||||
):
|
||||
engine_core_outputs = self.scheduler.update_from_output(
|
||||
scheduler_output, model_output
|
||||
)
|
||||
|
||||
# Block until the next result is available.
|
||||
future, scheduler_output = batch_queue.pop()
|
||||
with self.log_error_detail(scheduler_output):
|
||||
model_output = future.result()
|
||||
|
||||
engine_core_outputs = self.scheduler.update_from_output(
|
||||
scheduler_output, model_output
|
||||
)
|
||||
|
||||
# NOTE(nick): We can either handle the deferred tasks here or save
|
||||
# in a field and do it immediately once step_with_batch_queue is
|
||||
# re-called. The latter slightly favors TTFT over TPOT/throughput.
|
||||
if deferred_scheduler_output:
|
||||
with record_function_or_nullcontext(
|
||||
"core step_with_batch_queue: deferred_scheduler_output"
|
||||
):
|
||||
# We now have the tokens needed to compute the bitmask for the
|
||||
# deferred request. Get the bitmask and call sample tokens.
|
||||
grammar_output = self.scheduler.get_grammar_bitmask(
|
||||
deferred_scheduler_output
|
||||
)
|
||||
future = self.model_executor.sample_tokens(
|
||||
grammar_output, non_block=True
|
||||
)
|
||||
batch_queue.appendleft((future, deferred_scheduler_output))
|
||||
# We now have the tokens needed to compute the bitmask for the
|
||||
# deferred request. Get the bitmask and call sample tokens.
|
||||
grammar_output = self.scheduler.get_grammar_bitmask(
|
||||
deferred_scheduler_output
|
||||
)
|
||||
future = self.model_executor.sample_tokens(grammar_output, non_block=True)
|
||||
batch_queue.appendleft((future, deferred_scheduler_output))
|
||||
|
||||
return engine_core_outputs, model_executed
|
||||
|
||||
|
||||
@ -99,6 +99,11 @@ class RayDistributedExecutor(Executor):
|
||||
# KV connector setup
|
||||
self.has_connector = self.vllm_config.kv_transfer_config is not None
|
||||
|
||||
self.ec_producer = (
|
||||
self.vllm_config.ec_transfer_config is not None
|
||||
and self.vllm_config.ec_transfer_config.is_ec_producer
|
||||
)
|
||||
|
||||
self.scheduler_output: SchedulerOutput | None = None
|
||||
|
||||
@property
|
||||
@ -395,6 +400,12 @@ class RayDistributedExecutor(Executor):
|
||||
"State error: sample_tokens() must be called "
|
||||
"after execute_model() returns None."
|
||||
)
|
||||
|
||||
if self.ec_producer or not scheduler_output.total_num_scheduled_tokens:
|
||||
# Model will not execute, call model runner immediately.
|
||||
return self._execute_dag(scheduler_output, None, non_block)
|
||||
|
||||
# Model will execute, defer to sample_tokens() call.
|
||||
self.scheduler_output = scheduler_output
|
||||
return COMPLETED_NONE_FUTURE if non_block else None
|
||||
|
||||
@ -417,10 +428,18 @@ class RayDistributedExecutor(Executor):
|
||||
"""
|
||||
scheduler_output = self.scheduler_output
|
||||
if scheduler_output is None:
|
||||
return None # noqa
|
||||
return COMPLETED_NONE_FUTURE if non_block else None # noqa
|
||||
|
||||
self.scheduler_output = None
|
||||
|
||||
return self._execute_dag(scheduler_output, grammar_output, non_block)
|
||||
|
||||
def _execute_dag(
|
||||
self,
|
||||
scheduler_output: SchedulerOutput,
|
||||
grammar_output: "GrammarOutput | None",
|
||||
non_block: bool = False,
|
||||
) -> ModelRunnerOutput | Future[ModelRunnerOutput]:
|
||||
# Build the compiled DAG for the first time.
|
||||
if self.forward_dag is None: # type: ignore
|
||||
self.forward_dag = self._compiled_ray_dag(enable_asyncio=False)
|
||||
|
||||
@ -7,7 +7,7 @@ import time
|
||||
from collections import defaultdict
|
||||
from collections.abc import Iterator
|
||||
from contextlib import contextmanager
|
||||
from copy import deepcopy
|
||||
from copy import copy, deepcopy
|
||||
from functools import reduce
|
||||
from itertools import product
|
||||
from typing import TYPE_CHECKING, Any, NamedTuple, TypeAlias, cast
|
||||
@ -250,7 +250,6 @@ class ExecuteModelState(NamedTuple):
|
||||
hidden_states: torch.Tensor
|
||||
sample_hidden_states: torch.Tensor
|
||||
aux_hidden_states: list[torch.Tensor] | None
|
||||
kv_connector_output: KVConnectorOutput | None
|
||||
ec_connector_output: ECConnectorOutput | None
|
||||
|
||||
|
||||
@ -573,6 +572,7 @@ class GPUModelRunner(
|
||||
|
||||
# Ephemeral state transferred between execute_model() and sample_tokens().
|
||||
self.execute_model_state: ExecuteModelState | None = None
|
||||
self.kv_connector_output: KVConnectorOutput | None = None
|
||||
|
||||
def reset_mm_cache(self) -> None:
|
||||
if self.mm_budget:
|
||||
@ -2803,6 +2803,7 @@ class GPUModelRunner(
|
||||
# Return the intermediate tensors.
|
||||
assert isinstance(hidden_states, IntermediateTensors)
|
||||
hidden_states.kv_connector_output = kv_connector_output
|
||||
self.kv_connector_output = kv_connector_output
|
||||
return hidden_states
|
||||
|
||||
if self.is_pooling_model:
|
||||
@ -2853,18 +2854,31 @@ class GPUModelRunner(
|
||||
hidden_states,
|
||||
sample_hidden_states,
|
||||
aux_hidden_states,
|
||||
kv_connector_output,
|
||||
ec_connector_output,
|
||||
)
|
||||
self.kv_connector_output = kv_connector_output
|
||||
return None
|
||||
|
||||
@torch.inference_mode
|
||||
def sample_tokens(
|
||||
self, grammar_output: "GrammarOutput | None"
|
||||
) -> ModelRunnerOutput | AsyncModelRunnerOutput | IntermediateTensors:
|
||||
kv_connector_output = self.kv_connector_output
|
||||
self.kv_connector_output = None
|
||||
|
||||
if self.execute_model_state is None:
|
||||
# Nothing to do (PP non-final rank case), output isn't used.
|
||||
return None # noqa
|
||||
if not kv_connector_output:
|
||||
return None # noqa
|
||||
|
||||
# In case of PP with kv transfer, we need to pass through the
|
||||
# kv_connector_output
|
||||
if kv_connector_output.is_empty():
|
||||
return EMPTY_MODEL_RUNNER_OUTPUT
|
||||
|
||||
output = copy(EMPTY_MODEL_RUNNER_OUTPUT)
|
||||
output.kv_connector_output = kv_connector_output
|
||||
return output
|
||||
|
||||
# Unpack ephemeral state.
|
||||
(
|
||||
@ -2875,7 +2889,6 @@ class GPUModelRunner(
|
||||
hidden_states,
|
||||
sample_hidden_states,
|
||||
aux_hidden_states,
|
||||
kv_connector_output,
|
||||
ec_connector_output,
|
||||
) = self.execute_model_state
|
||||
# Clear ephemeral state.
|
||||
|
||||
@ -2,7 +2,6 @@
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""A GPU worker class."""
|
||||
|
||||
import copy
|
||||
import gc
|
||||
import os
|
||||
from contextlib import AbstractContextManager, nullcontext
|
||||
@ -45,7 +44,6 @@ from vllm.v1.core.sched.output import GrammarOutput
|
||||
from vllm.v1.engine import ReconfigureDistributedRequest, ReconfigureRankType
|
||||
from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec
|
||||
from vllm.v1.outputs import (
|
||||
EMPTY_MODEL_RUNNER_OUTPUT,
|
||||
AsyncModelRunnerOutput,
|
||||
DraftTokenIds,
|
||||
ModelRunnerOutput,
|
||||
@ -581,18 +579,7 @@ class Worker(WorkerBase):
|
||||
all_gather_tensors=all_gather_tensors,
|
||||
)
|
||||
|
||||
kv_connector_output = output.kv_connector_output
|
||||
if not kv_connector_output:
|
||||
return None
|
||||
|
||||
# In case of PP with kv transfer, we need to pass through the
|
||||
# kv_connector_output
|
||||
if kv_connector_output.is_empty():
|
||||
return EMPTY_MODEL_RUNNER_OUTPUT
|
||||
|
||||
output = copy.copy(EMPTY_MODEL_RUNNER_OUTPUT)
|
||||
output.kv_connector_output = kv_connector_output
|
||||
return output
|
||||
return None
|
||||
|
||||
def take_draft_token_ids(self) -> DraftTokenIds | None:
|
||||
return self.model_runner.take_draft_token_ids()
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user