diff --git a/tests/v1/engine/test_engine_core.py b/tests/v1/engine/test_engine_core.py
index 11c22effb122f..5fdbcf5b99636 100644
--- a/tests/v1/engine/test_engine_core.py
+++ b/tests/v1/engine/test_engine_core.py
@@ -102,14 +102,24 @@ def test_engine_core(monkeypatch):
     engine_core.add_request(req)
     assert len(engine_core.scheduler.waiting) == 1
     assert len(engine_core.scheduler.running) == 0
+    assert engine_core.scheduler.has_unfinished_requests()
+    assert not engine_core.scheduler.has_finished_requests()
 
     _ = engine_core.step()
     assert len(engine_core.scheduler.waiting) == 0
     assert len(engine_core.scheduler.running) == 1
+    assert engine_core.scheduler.has_unfinished_requests()
+    assert not engine_core.scheduler.has_finished_requests()
 
     engine_core.abort_requests([request_id])
     assert len(engine_core.scheduler.waiting) == 0
     assert len(engine_core.scheduler.running) == 0
+    assert not engine_core.scheduler.has_unfinished_requests()
+    assert engine_core.scheduler.has_finished_requests()
+
+    _ = engine_core.step()
+    assert not engine_core.scheduler.has_unfinished_requests()
+    assert not engine_core.scheduler.has_finished_requests()
 
     # Add, step, abort 1 of the 3.
     req0 = make_request()
diff --git a/tests/v1/engine/test_engine_core_client.py b/tests/v1/engine/test_engine_core_client.py
index 3880a3dd9b8ae..e646ccbd46030 100644
--- a/tests/v1/engine/test_engine_core_client.py
+++ b/tests/v1/engine/test_engine_core_client.py
@@ -50,7 +50,7 @@ def loop_until_done(client: EngineCoreClient, outputs: dict):
         engine_core_outputs = client.get_output().outputs
 
         if len(engine_core_outputs) == 0:
-            break
+            continue
 
         all_finished = True
         for out in engine_core_outputs:
@@ -68,7 +68,7 @@ async def loop_until_done_async(client: EngineCoreClient, outputs: dict):
         engine_core_outputs = (await client.get_output_async()).outputs
 
         if len(engine_core_outputs) == 0:
-            break
+            continue
 
         all_finished = True
         for out in engine_core_outputs:
diff --git a/vllm/v1/core/scheduler.py b/vllm/v1/core/scheduler.py
index 70e36e2dc1528..a7e50f8f40ece 100644
--- a/vllm/v1/core/scheduler.py
+++ b/vllm/v1/core/scheduler.py
@@ -682,7 +682,8 @@ class Scheduler:
         assert RequestStatus.is_finished(finished_status)
         if isinstance(request_ids, str):
             request_ids = (request_ids, )
-        request_ids = set(request_ids)
+        else:
+            request_ids = set(request_ids)
 
         for req_id in request_ids:
             request = self.requests.get(req_id)
@@ -714,6 +715,14 @@ class Scheduler:
     def has_unfinished_requests(self) -> bool:
         return self.get_num_unfinished_requests() > 0
 
+    def has_finished_requests(self) -> bool:
+        return len(self.finished_req_ids) > 0
+
+    def has_requests(self):
+        """Returns True if there are unfinished requests, or finished requests
+        not yet returned in SchedulerOutputs."""
+        return self.has_unfinished_requests() or self.has_finished_requests()
+
     def get_num_unscheduled_requests(self) -> int:
         """Number of requests that are not being processed by the executor."""
         return self.get_num_unfinished_requests() - len(self.scheduled_req_ids)
diff --git a/vllm/v1/engine/async_llm.py b/vllm/v1/engine/async_llm.py
index 32cbc10e16f66..3dc513a728339 100644
--- a/vllm/v1/engine/async_llm.py
+++ b/vllm/v1/engine/async_llm.py
@@ -253,13 +253,14 @@ class AsyncLLM(EngineClient):
             while True:
                 # 1) Pull EngineCoreOutputs from the EngineCore.
                 outputs = await self.engine_core.get_output_async()
+                num_outputs = len(outputs.outputs)
 
-                iteration_stats = IterationStats() if self.log_stats else None
+                iteration_stats = IterationStats() if (
+                    self.log_stats and num_outputs) else None
 
                 # Split outputs into chunks of at most
                 # VLLM_V1_OUTPUT_PROC_CHUNK_SIZE, so that we don't block the
                 # event loop for too long.
-                num_outputs = len(outputs.outputs)
                 if num_outputs <= VLLM_V1_OUTPUT_PROC_CHUNK_SIZE:
                     slices = (outputs.outputs, )
                 else:
@@ -313,7 +314,6 @@ class AsyncLLM(EngineClient):
             return
 
         assert scheduler_stats is not None
-        assert iteration_stats is not None
         for stat_logger in self.stat_loggers:
             stat_logger.record(scheduler_stats=scheduler_stats,
                                iteration_stats=iteration_stats)
diff --git a/vllm/v1/engine/core.py b/vllm/v1/engine/core.py
index e60aa5d45810f..bdf9203b1b1d5 100644
--- a/vllm/v1/engine/core.py
+++ b/vllm/v1/engine/core.py
@@ -153,7 +153,9 @@ class EngineCore:
 
     def step(self) -> EngineCoreOutputs:
         """Schedule, execute, and make output."""
-        if not self.scheduler.has_unfinished_requests():
+        # Check for any requests remaining in the scheduler - unfinished,
+        # or finished and not yet removed from the batch.
+        if not self.scheduler.has_requests():
             return EngineCoreOutputs(
                 outputs=[],
                 scheduler_stats=self.scheduler.make_stats(),
@@ -335,7 +337,7 @@ class EngineCoreProc(EngineCore):
         # Loop until process is sent a SIGINT or SIGTERM
         while True:
             # 1) Poll the input queue until there is work to do.
-            while not self.scheduler.has_unfinished_requests():
+            while not self.scheduler.has_requests():
                 logger.debug("EngineCore busy loop waiting.")
                 req = self.input_queue.get()
                 self._handle_client_request(*req)
diff --git a/vllm/v1/metrics/loggers.py b/vllm/v1/metrics/loggers.py
index 44493709b639a..fcb4d4f5a25a6 100644
--- a/vllm/v1/metrics/loggers.py
+++ b/vllm/v1/metrics/loggers.py
@@ -22,7 +22,7 @@ class StatLoggerBase(ABC):
 
     @abstractmethod
     def record(self, scheduler_stats: SchedulerStats,
-               iteration_stats: IterationStats):
+               iteration_stats: Optional[IterationStats]):
         ...
 
     def log(self):  # noqa
@@ -56,10 +56,11 @@ class LoggingStatLogger(StatLoggerBase):
         return float(np.sum(tracked_stats) / (now - self.last_log_time))
 
     def record(self, scheduler_stats: SchedulerStats,
-               iteration_stats: IterationStats):
+               iteration_stats: Optional[IterationStats]):
         """Log Stats to standard output."""
 
-        self._track_iteration_stats(iteration_stats)
+        if iteration_stats:
+            self._track_iteration_stats(iteration_stats)
 
         self.prefix_caching_metrics.observe(scheduler_stats.prefix_cache_stats)
 
@@ -319,7 +320,7 @@ class PrometheusStatLogger(StatLoggerBase):
         info_gauge.set(1)
 
     def record(self, scheduler_stats: SchedulerStats,
-               iteration_stats: IterationStats):
+               iteration_stats: Optional[IterationStats]):
         """Log to prometheus."""
         self.gauge_scheduler_running.set(scheduler_stats.num_running_reqs)
         self.gauge_scheduler_waiting.set(scheduler_stats.num_waiting_reqs)
@@ -331,6 +332,9 @@ class PrometheusStatLogger(StatLoggerBase):
         self.counter_gpu_prefix_cache_hits.inc(
             scheduler_stats.prefix_cache_stats.hits)
 
+        if iteration_stats is None:
+            return
+
         self.counter_num_preempted_reqs.inc(iteration_stats.num_preempted_reqs)
         self.counter_prompt_tokens.inc(iteration_stats.num_prompt_tokens)
         self.counter_generation_tokens.inc(
diff --git a/vllm/v1/outputs.py b/vllm/v1/outputs.py
index dc3ad402e0665..edae654b5d339 100644
--- a/vllm/v1/outputs.py
+++ b/vllm/v1/outputs.py
@@ -80,3 +80,13 @@ class ModelRunnerOutput:
     # [prompt_len, num_prompt_logprobs]
     # [prompt_len]
     prompt_logprobs_dict: dict[str, Optional[LogprobsTensors]]
+
+
+EMPTY_MODEL_RUNNER_OUTPUT = ModelRunnerOutput(
+    req_ids=[],
+    req_id_to_index={},
+    sampled_token_ids=[],
+    spec_token_ids=None,
+    logprobs=None,
+    prompt_logprobs_dict={},
+)
diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py
index 2484f0799b824..5cd7e25edcaaa 100644
--- a/vllm/v1/worker/gpu_model_runner.py
+++ b/vllm/v1/worker/gpu_model_runner.py
@@ -32,7 +32,8 @@ from vllm.v1.core.encoder_cache_manager import compute_encoder_budget
 from vllm.v1.engine.mm_input_cache import MMInputCacheClient
 from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
                                         KVCacheSpec)
-from vllm.v1.outputs import LogprobsTensors, ModelRunnerOutput
+from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, LogprobsTensors,
+                             ModelRunnerOutput)
 from vllm.v1.sample.metadata import SamplingMetadata
 from vllm.v1.sample.rejection_sampler import INVALID_TOKEN_ID, RejectionSampler
 from vllm.v1.spec_decode.ngram_proposer import NgramProposer
@@ -919,6 +920,9 @@ class GPUModelRunner(LoRAModelRunnerMixin):
         intermediate_tensors: Optional[IntermediateTensors] = None,
     ) -> Union[ModelRunnerOutput, torch.Tensor]:
         self._update_states(scheduler_output)
+        if not scheduler_output.total_num_scheduled_tokens:
+            # Return empty ModelRunnerOutput if there's no work to do.
+            return EMPTY_MODEL_RUNNER_OUTPUT
 
         if self.is_multimodal_model:
             # Run the multimodal encoder if any.
@@ -1069,7 +1073,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
             spec_token_ids = self.generate_draft_token_ids(
                 valid_sampled_token_ids)
 
-        model_runner_output = ModelRunnerOutput(
+        return ModelRunnerOutput(
            req_ids=self.input_batch.req_ids,
            req_id_to_index=self.input_batch.req_id_to_index,
            sampled_token_ids=valid_sampled_token_ids,
@@ -1077,7 +1081,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
             logprobs=logprobs_lists,
             prompt_logprobs_dict=prompt_logprobs_dict,
         )
-        return model_runner_output
 
     def generate_draft_token_ids(
         self,
diff --git a/vllm/v1/worker/tpu_model_runner.py b/vllm/v1/worker/tpu_model_runner.py
index f661412d9378e..d4ebb3adcf8dc 100644
--- a/vllm/v1/worker/tpu_model_runner.py
+++ b/vllm/v1/worker/tpu_model_runner.py
@@ -29,7 +29,8 @@ from vllm.v1.attention.backends.pallas import (NUM_KV_PAGES_PER_BLOCK,
 from vllm.v1.core.encoder_cache_manager import compute_encoder_budget
 from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
                                         KVCacheSpec)
-from vllm.v1.outputs import LogprobsTensors, ModelRunnerOutput
+from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, LogprobsTensors,
+                             ModelRunnerOutput)
 from vllm.v1.utils import bind_kv_cache
 from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
 
@@ -546,6 +547,9 @@ class TPUModelRunner:
     ) -> ModelRunnerOutput:
         # Update cached state
         self._update_states(scheduler_output)
+        if not scheduler_output.total_num_scheduled_tokens:
+            # Return empty ModelRunnerOutput if there's no work to do.
+            return EMPTY_MODEL_RUNNER_OUTPUT
 
         if self.is_multimodal_model:
             # Run the multimodal encoder if any.