[V1] Eagerly remove finished requests from the batch (#14388)

Signed-off-by: Nick Hill <nhill@redhat.com>
Nick Hill 2025-03-07 10:56:00 -08:00 committed by GitHub
parent c6359e8ca6
commit 8ed5421aaa
9 changed files with 58 additions and 16 deletions

View File

@@ -102,14 +102,24 @@ def test_engine_core(monkeypatch):
engine_core.add_request(req)
assert len(engine_core.scheduler.waiting) == 1
assert len(engine_core.scheduler.running) == 0
assert engine_core.scheduler.has_unfinished_requests()
assert not engine_core.scheduler.has_finished_requests()
_ = engine_core.step()
assert len(engine_core.scheduler.waiting) == 0
assert len(engine_core.scheduler.running) == 1
assert engine_core.scheduler.has_unfinished_requests()
assert not engine_core.scheduler.has_finished_requests()
engine_core.abort_requests([request_id])
assert len(engine_core.scheduler.waiting) == 0
assert len(engine_core.scheduler.running) == 0
assert not engine_core.scheduler.has_unfinished_requests()
assert engine_core.scheduler.has_finished_requests()
_ = engine_core.step()
assert not engine_core.scheduler.has_unfinished_requests()
assert not engine_core.scheduler.has_finished_requests()
# Add, step, abort 1 of the 3.
req0 = make_request()

View File

@@ -50,7 +50,7 @@ def loop_until_done(client: EngineCoreClient, outputs: dict):
engine_core_outputs = client.get_output().outputs
if len(engine_core_outputs) == 0:
- break
+ continue
all_finished = True
for out in engine_core_outputs:
@@ -68,7 +68,7 @@ async def loop_until_done_async(client: EngineCoreClient, outputs: dict):
engine_core_outputs = (await client.get_output_async()).outputs
if len(engine_core_outputs) == 0:
- break
+ continue
all_finished = True
for out in engine_core_outputs:

View File

@@ -682,7 +682,8 @@ class Scheduler:
assert RequestStatus.is_finished(finished_status)
if isinstance(request_ids, str):
request_ids = (request_ids, )
- request_ids = set(request_ids)
+ else:
+     request_ids = set(request_ids)
for req_id in request_ids:
request = self.requests.get(req_id)
@@ -714,6 +715,14 @@
def has_unfinished_requests(self) -> bool:
return self.get_num_unfinished_requests() > 0
def has_finished_requests(self) -> bool:
return len(self.finished_req_ids) > 0
def has_requests(self):
"""Returns True if there are unfinished requests, or finished requests
not yet returned in SchedulerOutputs."""
return self.has_unfinished_requests() or self.has_finished_requests()
def get_num_unscheduled_requests(self) -> int:
"""Number of requests that are not being processed by the executor."""
return self.get_num_unfinished_requests() - len(self.scheduled_req_ids)
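
How the new helpers are meant to interact is easier to see in isolation. Below is a minimal sketch, not the real Scheduler: only finished_req_ids, has_unfinished_requests, has_finished_requests and has_requests mirror the diff above; MiniScheduler and its other methods are illustrative stand-ins.

class MiniScheduler:
    """Toy model of the bookkeeping behind has_requests()."""

    def __init__(self):
        self.requests = {}             # stands in for waiting + running
        self.finished_req_ids = set()  # finished, not yet emitted in a SchedulerOutput

    def add_request(self, req_id):
        self.requests[req_id] = object()

    def finish_requests(self, req_ids):
        # Aborted/finished requests are remembered so the next schedule()
        # can tell the model runner to drop them from its persistent batch.
        for req_id in req_ids:
            if req_id in self.requests:
                del self.requests[req_id]
                self.finished_req_ids.add(req_id)

    def schedule(self):
        # schedule() drains finished_req_ids into the output it returns.
        finished, self.finished_req_ids = self.finished_req_ids, set()
        return finished

    def has_unfinished_requests(self):
        return bool(self.requests)

    def has_finished_requests(self):
        return bool(self.finished_req_ids)

    def has_requests(self):
        return self.has_unfinished_requests() or self.has_finished_requests()


sched = MiniScheduler()
sched.add_request("r1")
sched.finish_requests(["r1"])        # aborted before its removal was scheduled
assert not sched.has_unfinished_requests()
assert sched.has_requests()          # one more schedule() is still needed
assert sched.schedule() == {"r1"}
assert not sched.has_requests()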

View File

@@ -253,13 +253,14 @@ class AsyncLLM(EngineClient):
while True:
# 1) Pull EngineCoreOutputs from the EngineCore.
outputs = await self.engine_core.get_output_async()
+ num_outputs = len(outputs.outputs)
- iteration_stats = IterationStats() if self.log_stats else None
+ iteration_stats = IterationStats() if (
+     self.log_stats and num_outputs) else None
# Split outputs into chunks of at most
# VLLM_V1_OUTPUT_PROC_CHUNK_SIZE, so that we don't block the
# event loop for too long.
- num_outputs = len(outputs.outputs)
if num_outputs <= VLLM_V1_OUTPUT_PROC_CHUNK_SIZE:
slices = (outputs.outputs, )
else:
@@ -313,7 +314,6 @@ class AsyncLLM(EngineClient):
return
assert scheduler_stats is not None
- assert iteration_stats is not None
for stat_logger in self.stat_loggers:
stat_logger.record(scheduler_stats=scheduler_stats,
iteration_stats=iteration_stats)
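
Two things happen in the hunks above: IterationStats is skipped entirely for the empty outputs that the extra flush steps can now produce, and (per the existing comment) large output batches are still split into chunks so output processing does not hog the event loop. A rough sketch of that chunk-and-yield pattern follows; handle_outputs, process_chunk and OUTPUT_PROC_CHUNK_SIZE are illustrative stand-ins, not the actual vLLM symbols.

import asyncio

OUTPUT_PROC_CHUNK_SIZE = 128  # stand-in for VLLM_V1_OUTPUT_PROC_CHUNK_SIZE

async def handle_outputs(outputs, process_chunk):
    num_outputs = len(outputs)
    if num_outputs <= OUTPUT_PROC_CHUNK_SIZE:
        slices = (outputs, )
    else:
        slices = [outputs[i:i + OUTPUT_PROC_CHUNK_SIZE]
                  for i in range(0, num_outputs, OUTPUT_PROC_CHUNK_SIZE)]

    for i, chunk in enumerate(slices):
        process_chunk(chunk)
        # Yield to the event loop between chunks so other coroutines
        # are not starved while a very large batch is processed.
        if i + 1 < len(slices):
            await asyncio.sleep(0)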

View File

@@ -153,7 +153,9 @@ class EngineCore:
def step(self) -> EngineCoreOutputs:
"""Schedule, execute, and make output."""
- if not self.scheduler.has_unfinished_requests():
+ # Check for any requests remaining in the scheduler - unfinished,
+ # or finished and not yet removed from the batch.
+ if not self.scheduler.has_requests():
return EngineCoreOutputs(
outputs=[],
scheduler_stats=self.scheduler.make_stats(),
@@ -335,7 +337,7 @@ class EngineCoreProc(EngineCore):
# Loop until process is sent a SIGINT or SIGTERM
while True:
# 1) Poll the input queue until there is work to do.
- while not self.scheduler.has_unfinished_requests():
+ while not self.scheduler.has_requests():
logger.debug("EngineCore busy loop waiting.")
req = self.input_queue.get()
self._handle_client_request(*req)
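
Previously this loop blocked as soon as there were no unfinished requests, so a request aborted mid-batch could sit in the model runner's persistent batch until new work arrived. With has_requests() the engine takes one extra step after the last abort, purely to propagate the removal. A simplified driver loop showing the effect (run_until_idle is illustrative; step() and has_requests() are the methods from the diff):

def run_until_idle(engine_core):
    outputs = []
    # has_requests() stays True while finished requests are still waiting to
    # be flushed, so the final step below removes them from the batch even
    # though it schedules zero new tokens.
    while engine_core.scheduler.has_requests():
        outputs.append(engine_core.step())
    return outputs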

View File

@@ -22,7 +22,7 @@ class StatLoggerBase(ABC):
@abstractmethod
def record(self, scheduler_stats: SchedulerStats,
- iteration_stats: IterationStats):
+ iteration_stats: Optional[IterationStats]):
...
def log(self): # noqa
@@ -56,10 +56,11 @@ class LoggingStatLogger(StatLoggerBase):
return float(np.sum(tracked_stats) / (now - self.last_log_time))
def record(self, scheduler_stats: SchedulerStats,
- iteration_stats: IterationStats):
+ iteration_stats: Optional[IterationStats]):
"""Log Stats to standard output."""
- self._track_iteration_stats(iteration_stats)
+ if iteration_stats:
+     self._track_iteration_stats(iteration_stats)
self.prefix_caching_metrics.observe(scheduler_stats.prefix_cache_stats)
@@ -319,7 +320,7 @@ class PrometheusStatLogger(StatLoggerBase):
info_gauge.set(1)
def record(self, scheduler_stats: SchedulerStats,
- iteration_stats: IterationStats):
+ iteration_stats: Optional[IterationStats]):
"""Log to prometheus."""
self.gauge_scheduler_running.set(scheduler_stats.num_running_reqs)
self.gauge_scheduler_waiting.set(scheduler_stats.num_waiting_reqs)
@@ -331,6 +332,9 @@
self.counter_gpu_prefix_cache_hits.inc(
scheduler_stats.prefix_cache_stats.hits)
if iteration_stats is None:
return
self.counter_num_preempted_reqs.inc(iteration_stats.num_preempted_reqs)
self.counter_prompt_tokens.inc(iteration_stats.num_prompt_tokens)
self.counter_generation_tokens.inc(

View File

@@ -80,3 +80,13 @@ class ModelRunnerOutput:
# [prompt_len, num_prompt_logprobs]
# [prompt_len]
prompt_logprobs_dict: dict[str, Optional[LogprobsTensors]]
EMPTY_MODEL_RUNNER_OUTPUT = ModelRunnerOutput(
req_ids=[],
req_id_to_index={},
sampled_token_ids=[],
spec_token_ids=None,
logprobs=None,
prompt_logprobs_dict={},
)
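
Both model-runner diffs below use this constant the same way; stripped of surrounding detail the pattern is roughly the sketch here (simplified and illustrative: only _update_states, total_num_scheduled_tokens and EMPTY_MODEL_RUNNER_OUTPUT come from the diffs).

def execute_model(self, scheduler_output, intermediate_tensors=None):
    # Apply the scheduler's bookkeeping first; this is where finished or
    # aborted requests get dropped from the runner's persistent batch.
    self._update_states(scheduler_output)
    if not scheduler_output.total_num_scheduled_tokens:
        # The step existed only to flush finished requests; nothing to run,
        # so hand back the shared empty output.
        return EMPTY_MODEL_RUNNER_OUTPUT
    # ... normal forward pass / sampling continues here ...

Reusing a single module-level instance presumably just avoids rebuilding an identical empty ModelRunnerOutput on every such flush step.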

View File

@@ -32,7 +32,8 @@ from vllm.v1.core.encoder_cache_manager import compute_encoder_budget
from vllm.v1.engine.mm_input_cache import MMInputCacheClient
from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
KVCacheSpec)
- from vllm.v1.outputs import LogprobsTensors, ModelRunnerOutput
+ from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, LogprobsTensors,
+     ModelRunnerOutput)
from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.sample.rejection_sampler import INVALID_TOKEN_ID, RejectionSampler
from vllm.v1.spec_decode.ngram_proposer import NgramProposer
@@ -919,6 +920,9 @@ class GPUModelRunner(LoRAModelRunnerMixin):
intermediate_tensors: Optional[IntermediateTensors] = None,
) -> Union[ModelRunnerOutput, torch.Tensor]:
self._update_states(scheduler_output)
if not scheduler_output.total_num_scheduled_tokens:
# Return empty ModelRunnerOutput if there's no work to do.
return EMPTY_MODEL_RUNNER_OUTPUT
if self.is_multimodal_model:
# Run the multimodal encoder if any.
@@ -1069,7 +1073,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
spec_token_ids = self.generate_draft_token_ids(
valid_sampled_token_ids)
- model_runner_output = ModelRunnerOutput(
+ return ModelRunnerOutput(
req_ids=self.input_batch.req_ids,
req_id_to_index=self.input_batch.req_id_to_index,
sampled_token_ids=valid_sampled_token_ids,
@@ -1077,7 +1081,6 @@
logprobs=logprobs_lists,
prompt_logprobs_dict=prompt_logprobs_dict,
)
- return model_runner_output
def generate_draft_token_ids(
self,

View File

@@ -29,7 +29,8 @@ from vllm.v1.attention.backends.pallas import (NUM_KV_PAGES_PER_BLOCK,
from vllm.v1.core.encoder_cache_manager import compute_encoder_budget
from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
KVCacheSpec)
- from vllm.v1.outputs import LogprobsTensors, ModelRunnerOutput
+ from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, LogprobsTensors,
+     ModelRunnerOutput)
from vllm.v1.utils import bind_kv_cache
from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
@@ -546,6 +547,9 @@
) -> ModelRunnerOutput:
# Update cached state
self._update_states(scheduler_output)
if not scheduler_output.total_num_scheduled_tokens:
# Return empty ModelRunnerOutput if there's no work to do.
return EMPTY_MODEL_RUNNER_OUTPUT
if self.is_multimodal_model:
# Run the multimodal encoder if any.