From 6f7de33bed412869bec4631add885e5ff88c22cf Mon Sep 17 00:00:00 2001 From: Mark McLoughlin Date: Mon, 10 Nov 2025 08:34:36 +0000 Subject: [PATCH] [Metrics] Refactor LoRA state tracking (#26801) Signed-off-by: Mark McLoughlin --- tests/v1/engine/test_output_processor.py | 175 ++++++++++++++++++++++- tests/v1/metrics/test_stats.py | 18 +-- vllm/v1/engine/async_llm.py | 2 + vllm/v1/engine/llm_engine.py | 1 + vllm/v1/engine/output_processor.py | 23 +-- vllm/v1/metrics/loggers.py | 28 ++-- vllm/v1/metrics/stats.py | 125 ++++++++-------- 7 files changed, 267 insertions(+), 105 deletions(-) diff --git a/tests/v1/engine/test_output_processor.py b/tests/v1/engine/test_output_processor.py index 28ebe5166d962..d77a119ec60f8 100644 --- a/tests/v1/engine/test_output_processor.py +++ b/tests/v1/engine/test_output_processor.py @@ -15,12 +15,19 @@ from tests.v1.engine.utils import ( ) from vllm import PoolingParams from vllm.logprobs import PromptLogprobs, SampleLogprobs +from vllm.lora.request import LoRARequest from vllm.outputs import CompletionOutput, RequestOutput from vllm.sampling_params import RequestOutputKind, SamplingParams from vllm.transformers_utils.tokenizer import AnyTokenizer -from vllm.v1.engine import EngineCoreRequest +from vllm.v1.engine import ( + EngineCoreEvent, + EngineCoreEventType, + EngineCoreOutputs, + EngineCoreRequest, + FinishReason, +) from vllm.v1.engine.output_processor import OutputProcessor, RequestOutputCollector -from vllm.v1.metrics.stats import IterationStats +from vllm.v1.metrics.stats import IterationStats, SchedulerStats def _ref_convert_id_to_token( @@ -895,6 +902,170 @@ def test_iteration_stats(dummy_test_vectors): assert iteration_stats.num_generation_tokens == num_active +@pytest.mark.parametrize("log_stats", [True, False]) +def test_lora_request_tracking(log_stats: bool, dummy_test_vectors): + """Test LoRA request lifecycle tracking through waiting -> running -> finished.""" + output_processor = OutputProcessor( + dummy_test_vectors.tokenizer, log_stats=log_stats + ) + engine_core = MockEngineCore(dummy_test_vectors.generation_tokens) + engine_core_timestamp = time.monotonic() + + # Create LoRA requests + lora1 = LoRARequest(lora_name="lora-1", lora_int_id=1, lora_path="/path/to/lora1") + lora2 = LoRARequest(lora_name="lora-2", lora_int_id=2, lora_path="/path/to/lora2") + + # Create requests with different LoRA adapters: + # - request-0: lora-1 + # - request-1: lora-2 + # - request-2: None (no LoRA) + lora_assignments = [lora1, lora2, None] + requests = [ + EngineCoreRequest( + request_id=f"request-{idx}", + prompt_token_ids=prompt_tokens, + mm_features=None, + eos_token_id=None, + arrival_time=0, + lora_request=lora_assignments[idx], + cache_salt=None, + data_parallel_rank=None, + sampling_params=SamplingParams(), + pooling_params=None, + ) + for idx, prompt_tokens in enumerate(dummy_test_vectors.prompt_tokens) + ] + + # Add all requests to the OutputProcessor + for request in requests: + output_processor.add_request(request, None) + + # First iteration: process outputs with QUEUED events + outputs = EngineCoreOutputs( + outputs=engine_core.get_outputs(), scheduler_stats=SchedulerStats() + ) + for output in outputs.outputs: + output.events = [ + EngineCoreEvent.new_event(EngineCoreEventType.QUEUED, engine_core_timestamp) + ] + + iteration_stats = IterationStats() if log_stats else None + output_processor.process_outputs( + outputs.outputs, engine_core_timestamp, iteration_stats + ) + output_processor.update_scheduler_stats(outputs.scheduler_stats) + + if 
log_stats: + # Verify waiting counts + assert outputs.scheduler_stats.waiting_lora_adapters.get("lora-1") == 1 + assert outputs.scheduler_stats.waiting_lora_adapters.get("lora-2") == 1 + assert outputs.scheduler_stats.running_lora_adapters.get("lora-1") == 0 + assert outputs.scheduler_stats.running_lora_adapters.get("lora-2") == 0 + # Verify internal state + assert len(output_processor.lora_states.requests) == 2 + assert "lora-1" in output_processor.lora_states.requests + assert "lora-2" in output_processor.lora_states.requests + else: + # When log_stats=False, no tracking should occur + assert iteration_stats is None + assert len(output_processor.lora_states.requests) == 0 + + # Second iteration: process outputs with SCHEDULED events + outputs = EngineCoreOutputs( + outputs=engine_core.get_outputs(), scheduler_stats=SchedulerStats() + ) + for output in outputs.outputs: + output.events = [ + EngineCoreEvent.new_event( + EngineCoreEventType.SCHEDULED, engine_core_timestamp + ) + ] + + iteration_stats = IterationStats() if log_stats else None + output_processor.process_outputs( + outputs.outputs, engine_core_timestamp, iteration_stats + ) + output_processor.update_scheduler_stats(outputs.scheduler_stats) + + if log_stats: + # Verify running counts + assert outputs.scheduler_stats.waiting_lora_adapters.get("lora-1") == 0 + assert outputs.scheduler_stats.waiting_lora_adapters.get("lora-2") == 0 + assert outputs.scheduler_stats.running_lora_adapters.get("lora-1") == 1 + assert outputs.scheduler_stats.running_lora_adapters.get("lora-2") == 1 + else: + assert iteration_stats is None + assert len(output_processor.lora_states.requests) == 0 + + # Third iteration: finish request-0 (lora-1) + outputs = EngineCoreOutputs( + outputs=engine_core.get_outputs(), scheduler_stats=SchedulerStats() + ) + # Find and mark request-0 as finished (it uses lora-1) + for output in outputs.outputs: + if output.request_id == "request-0": + output.finish_reason = FinishReason.LENGTH + break + + iteration_stats = IterationStats() if log_stats else None + output_processor.process_outputs( + outputs.outputs, engine_core_timestamp, iteration_stats + ) + output_processor.update_scheduler_stats(outputs.scheduler_stats) + + if log_stats: + # lora-1 should be removed since no requests remain + assert "lora-1" not in output_processor.lora_states.requests + # lora-2 should still be running + assert outputs.scheduler_stats.running_lora_adapters.get("lora-2") == 1 + assert len(output_processor.lora_states.requests) == 1 + else: + assert len(output_processor.lora_states.requests) == 0 + + # Fourth iteration: finish request-1 (lora-2) + outputs = EngineCoreOutputs( + outputs=engine_core.get_outputs(), scheduler_stats=SchedulerStats() + ) + # Find and mark request-1 as finished (it uses lora-2) + for output in outputs.outputs: + if output.request_id == "request-1": + output.finish_reason = FinishReason.LENGTH + break + + iteration_stats = IterationStats() if log_stats else None + output_processor.process_outputs( + outputs.outputs, engine_core_timestamp, iteration_stats + ) + output_processor.update_scheduler_stats(outputs.scheduler_stats) + + if log_stats: + # lora-2 should be removed since no requests remain + assert "lora-2" not in output_processor.lora_states.requests + assert len(outputs.scheduler_stats.running_lora_adapters) == 0 + assert len(output_processor.lora_states.requests) == 0 + else: + assert len(output_processor.lora_states.requests) == 0 + + # Finish the last request (no LoRA) + outputs = EngineCoreOutputs( + 
outputs=engine_core.get_outputs(), scheduler_stats=SchedulerStats() + ) + # Find and mark request-2 as finished (it has no LoRA) + for output in outputs.outputs: + if output.request_id == "request-2": + output.finish_reason = FinishReason.LENGTH + break + + iteration_stats = IterationStats() if log_stats else None + output_processor.process_outputs( + outputs.outputs, engine_core_timestamp, iteration_stats + ) + output_processor.update_scheduler_stats(outputs.scheduler_stats) + + # Verify all requests are finished + assert output_processor.get_num_unfinished_requests() == 0 + + @pytest.mark.asyncio async def test_request_output_collector(): NUM_REQS = 3 diff --git a/tests/v1/metrics/test_stats.py b/tests/v1/metrics/test_stats.py index b12e97a875f84..48067def8357e 100644 --- a/tests/v1/metrics/test_stats.py +++ b/tests/v1/metrics/test_stats.py @@ -5,20 +5,4 @@ from vllm.v1.metrics.stats import IterationStats def test_iteration_stats_repr(): iteration_stats = IterationStats() - iteration_stats.iteration_timestamp = 0 - expected_repr = ( - "IterationStats(" - "iteration_timestamp=0, " - "num_generation_tokens=0, " - "num_prompt_tokens=0, " - "num_preempted_reqs=0, " - "finished_requests=[], " - "max_num_generation_tokens_iter=[], " - "n_params_iter=[], " - "time_to_first_tokens_iter=[], " - "inter_token_latencies_iter=[], " - "waiting_lora_adapters={}, " - "running_lora_adapters={}, " - "num_corrupted_reqs=0)" - ) - assert repr(iteration_stats) == expected_repr + assert repr(iteration_stats).startswith("IterationStats(") diff --git a/vllm/v1/engine/async_llm.py b/vllm/v1/engine/async_llm.py index f0d5b77e8e183..aee21fb3fffe7 100644 --- a/vllm/v1/engine/async_llm.py +++ b/vllm/v1/engine/async_llm.py @@ -508,6 +508,8 @@ class AsyncLLM(EngineClient): processed_outputs.reqs_to_abort ) + output_processor.update_scheduler_stats(outputs.scheduler_stats) + # 4) Logging. # TODO(rob): make into a coroutine and launch it in # background thread once Prometheus overhead is non-trivial. diff --git a/vllm/v1/engine/llm_engine.py b/vllm/v1/engine/llm_engine.py index 995642a8356fc..e32c74aff313a 100644 --- a/vllm/v1/engine/llm_engine.py +++ b/vllm/v1/engine/llm_engine.py @@ -289,6 +289,7 @@ class LLMEngine: engine_core_timestamp=outputs.timestamp, iteration_stats=iteration_stats, ) + self.output_processor.update_scheduler_stats(outputs.scheduler_stats) # 3) Abort any reqs that finished due to stop strings. 
self.engine_core.abort_requests(processed_outputs.reqs_to_abort) diff --git a/vllm/v1/engine/output_processor.py b/vllm/v1/engine/output_processor.py index 07c8113dd9b33..d8d03f19d4663 100644 --- a/vllm/v1/engine/output_processor.py +++ b/vllm/v1/engine/output_processor.py @@ -22,7 +22,12 @@ from vllm.v1.engine import EngineCoreOutput, EngineCoreRequest, FinishReason from vllm.v1.engine.detokenizer import IncrementalDetokenizer from vllm.v1.engine.logprobs import LogprobsProcessor from vllm.v1.engine.parallel_sampling import ParentRequest -from vllm.v1.metrics.stats import IterationStats, LoRARequestStates, RequestStateStats +from vllm.v1.metrics.stats import ( + IterationStats, + LoRARequestStates, + RequestStateStats, + SchedulerStats, +) class RequestOutputCollector: @@ -310,7 +315,7 @@ class OutputProcessor: self.tokenizer = tokenizer self.request_states: dict[str, RequestState] = {} self.parent_requests: dict[str, ParentRequest] = {} - self.lora_states = LoRARequestStates() + self.lora_states = LoRARequestStates(log_stats) self.tracer: Tracer | None = None def get_num_unfinished_requests(self): @@ -334,7 +339,7 @@ class OutputProcessor: for request_id in request_ids: req_state = self.request_states.pop(request_id, None) if req_state is not None: - self.lora_states.abort_request(req_state) + self.lora_states.request_finished(request_id, req_state.lora_name) request_ids_to_abort.append(request_id) # Produce final abort output. if req_state.queue is not None and ( @@ -382,7 +387,6 @@ class OutputProcessor: log_stats=self.log_stats, ) self.request_states[request_id] = req_state - self.lora_states.add_request(req_state) if parent_req: self.parent_requests[parent_req.request_id] = parent_req @@ -484,13 +488,15 @@ class OutputProcessor: ) if self.tracer: self.do_tracing(engine_core_output, req_state, iteration_stats) - self.lora_states.update_iteration_stats(iteration_stats) return OutputProcessorOutput( request_outputs=request_outputs, reqs_to_abort=reqs_to_abort, ) + def update_scheduler_stats(self, scheduler_stats: SchedulerStats | None): + self.lora_states.update_scheduler_stats(scheduler_stats) + def do_tracing( self, engine_core_output: EngineCoreOutput, @@ -564,8 +570,6 @@ class OutputProcessor: if iteration_stats is None: return - lora_stats = self.lora_states.get_stats(req_state) - assert engine_core_timestamp is not None assert req_state.stats is not None iteration_stats.update_from_output( @@ -574,7 +578,8 @@ class OutputProcessor: req_state.is_prefilling, req_state.prompt_len, req_state.stats, - lora_stats, + self.lora_states, + req_state.lora_name, ) def _update_stats_from_finished( @@ -596,7 +601,7 @@ class OutputProcessor: max_tokens_param=req_state.max_tokens_param, req_stats=req_state.stats, ) - self.lora_states.finish_request(req_state) + self.lora_states.request_finished(req_state.request_id, req_state.lora_name) ParentRequest.observe_finished_request( req_state.parent_req, iteration_stats, req_state.stats.num_generation_tokens diff --git a/vllm/v1/metrics/loggers.py b/vllm/v1/metrics/loggers.py index eb113c74a22a9..1a175e9e110bd 100644 --- a/vllm/v1/metrics/loggers.py +++ b/vllm/v1/metrics/loggers.py @@ -989,6 +989,20 @@ class PrometheusStatLogger(AggregateStatLoggerBase): scheduler_stats.kv_connector_stats, engine_idx ) + if self.gauge_lora_info is not None: + running_lora_adapters = ",".join( + scheduler_stats.running_lora_adapters.keys() + ) + waiting_lora_adapters = ",".join( + scheduler_stats.waiting_lora_adapters.keys() + ) + lora_info_labels = { + 
self.labelname_running_lora_adapters: running_lora_adapters, + self.labelname_waiting_lora_adapters: waiting_lora_adapters, + self.labelname_max_lora: self.max_lora, + } + self.gauge_lora_info.labels(**lora_info_labels).set_to_current_time() + if mm_cache_stats is not None: self.counter_mm_cache_queries[engine_idx].inc(mm_cache_stats.queries) self.counter_mm_cache_hits[engine_idx].inc(mm_cache_stats.hits) @@ -1055,20 +1069,6 @@ class PrometheusStatLogger(AggregateStatLoggerBase): finished_request.max_tokens_param ) - if self.gauge_lora_info is not None: - running_lora_adapters = ",".join( - iteration_stats.running_lora_adapters.keys() - ) - waiting_lora_adapters = ",".join( - iteration_stats.waiting_lora_adapters.keys() - ) - lora_info_labels = { - self.labelname_running_lora_adapters: running_lora_adapters, - self.labelname_waiting_lora_adapters: waiting_lora_adapters, - self.labelname_max_lora: self.max_lora, - } - self.gauge_lora_info.labels(**lora_info_labels).set_to_current_time() - def record_sleep_state(self, sleep: int = 0, level: int = 0): awake = 1 discard_all = 0 diff --git a/vllm/v1/metrics/stats.py b/vllm/v1/metrics/stats.py index c5f06a66e21e6..4e9db98db0bc2 100644 --- a/vllm/v1/metrics/stats.py +++ b/vllm/v1/metrics/stats.py @@ -2,7 +2,7 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import time -from collections import deque +from collections import defaultdict, deque from dataclasses import dataclass, field from typing import TYPE_CHECKING, Any @@ -11,7 +11,6 @@ from vllm.v1.spec_decode.metrics import SpecDecodingStats if TYPE_CHECKING: from vllm.v1.engine import EngineCoreEvent, EngineCoreOutput, FinishReason - from vllm.v1.engine.output_processor import RequestState @dataclass @@ -170,11 +169,8 @@ class SchedulerStats: spec_decoding_stats: SpecDecodingStats | None = None kv_connector_stats: dict[str, Any] | None = None - -@dataclass -class LoRAStats: - waiting_requests: set[str] = field(default_factory=set) - running_requests: set[str] = field(default_factory=set) + waiting_lora_adapters: dict[str, int] = field(default_factory=dict) + running_lora_adapters: dict[str, int] = field(default_factory=dict) @dataclass @@ -229,8 +225,6 @@ class IterationStats: self.n_params_iter: list[int] = [] self.time_to_first_tokens_iter: list[float] = [] self.inter_token_latencies_iter: list[float] = [] - self.waiting_lora_adapters: dict[str, int] = {} - self.running_lora_adapters: dict[str, int] = {} self.num_corrupted_reqs: int = 0 def __repr__(self) -> str: @@ -248,7 +242,8 @@ class IterationStats: is_prefilling: bool, prompt_len: int, req_stats: RequestStateStats, - lora_stats: LoRAStats | None, + lora_states: "LoRARequestStates", + lora_name: str | None, ): num_new_generation_tokens = len(output.new_token_ids) @@ -274,7 +269,12 @@ class IterationStats: # Process request-level engine core events if output.events is not None: self.update_from_events( - output.request_id, output.events, is_prefilling, req_stats, lora_stats + output.request_id, + output.events, + is_prefilling, + req_stats, + lora_states, + lora_name, ) # Process the batch-level "new tokens" engine core event @@ -292,7 +292,8 @@ class IterationStats: events: list["EngineCoreEvent"], is_prefilling: bool, req_stats: RequestStateStats, - lora_stats: LoRAStats | None, + lora_states: "LoRARequestStates", + lora_name: str | None, ): # Avoid circular dependency from vllm.v1.engine import EngineCoreEventType @@ -300,15 +301,14 @@ class IterationStats: for event in events: if event.type == 
EngineCoreEventType.QUEUED: req_stats.queued_ts = event.timestamp - if lora_stats is not None: - lora_stats.waiting_requests.add(req_id) + lora_states.request_waiting(req_id, lora_name) elif event.type == EngineCoreEventType.SCHEDULED: if req_stats.scheduled_ts == 0.0: # ignore preemptions req_stats.scheduled_ts = event.timestamp - LoRARequestStates.scheduled_request(lora_stats, req_id) + lora_states.request_running(req_id, lora_name) elif event.type == EngineCoreEventType.PREEMPTED: self.num_preempted_reqs += 1 - LoRARequestStates.preempted_request(lora_stats, req_id) + lora_states.request_waiting(req_id, lora_name) def update_from_finished_request( self, @@ -361,61 +361,60 @@ class IterationStats: self.num_corrupted_reqs += 1 -class LoRARequestStates: - """Per-LoRA request state stats.""" +class LoRAStats: + """Tracks waiting and running request IDs for a single LoRA.""" def __init__(self): - self.lora_name_to_stats: dict[str, LoRAStats] = {} + self.waiting: set[str] = set() + self.running: set[str] = set() - def get_stats(self, req_state: "RequestState") -> LoRAStats | None: - if req_state.lora_name is None: - return None - if req_state.lora_name not in self.lora_name_to_stats: - self.lora_name_to_stats[req_state.lora_name] = LoRAStats() - return self.lora_name_to_stats[req_state.lora_name] + def update(self, req_id: str, waiting: bool, running: bool): + assert not (waiting and running) + if waiting: + self.waiting.add(req_id) + else: + self.waiting.discard(req_id) - def add_request(self, req_state: "RequestState"): - if (lora_stats := self.get_stats(req_state)) is not None: - lora_stats.waiting_requests.add(req_state.request_id) + if running: + self.running.add(req_id) + else: + self.running.discard(req_id) - def finish_request(self, req_state: "RequestState"): - if req_state.lora_name is None: + @property + def empty(self) -> bool: + return not (self.waiting or self.running) + + +class LoRARequestStates: + """A per-LoRA count of running and waiting requests.""" + + def __init__(self, log_stats: bool = False): + self.log_stats = log_stats + self.requests: defaultdict[str, LoRAStats] = defaultdict(LoRAStats) + + def _request_update( + self, req_id: str, lora_name: str | None, waiting: bool, running: bool + ): + if not self.log_stats or lora_name is None: return - lora_stats = self.lora_name_to_stats[req_state.lora_name] - lora_stats.running_requests.remove(req_state.request_id) - def abort_request(self, req_state: "RequestState"): - if req_state.lora_name is None: - return - lora_stats = self.lora_name_to_stats[req_state.lora_name] - lora_stats.waiting_requests.discard(req_state.request_id) - lora_stats.running_requests.discard(req_state.request_id) + lora_stats = self.requests[lora_name] + lora_stats.update(req_id, waiting, running) + if lora_stats.empty: + del self.requests[lora_name] - # Break the pattern for this lifecycle methods so we can - # call this from IterationStats.update_from_events() - @staticmethod - def scheduled_request(lora_stats: LoRAStats | None, request_id: str): - if lora_stats is None: - return - lora_stats.waiting_requests.remove(request_id) - lora_stats.running_requests.add(request_id) + def request_waiting(self, req_id: str, lora_name: str | None): + self._request_update(req_id, lora_name, waiting=True, running=False) - @staticmethod - def preempted_request(lora_stats: LoRAStats | None, request_id: str): - if lora_stats is None: - return - lora_stats.running_requests.remove(request_id) - lora_stats.waiting_requests.add(request_id) + def request_running(self, 
req_id: str, lora_name: str | None): + self._request_update(req_id, lora_name, waiting=False, running=True) - def update_iteration_stats(self, iteration_stats: IterationStats | None): - if iteration_stats is None: + def request_finished(self, req_id: str, lora_name: str | None): + self._request_update(req_id, lora_name, waiting=False, running=False) + + def update_scheduler_stats(self, scheduler_stats: SchedulerStats | None): + if not self.log_stats or scheduler_stats is None: return - for lora_name, stats in self.lora_name_to_stats.items(): - if stats.waiting_requests: - iteration_stats.waiting_lora_adapters[lora_name] = len( - stats.waiting_requests - ) - if stats.running_requests: - iteration_stats.running_lora_adapters[lora_name] = len( - stats.running_requests - ) + for lora_name, stats in self.requests.items(): + scheduler_stats.waiting_lora_adapters[lora_name] = len(stats.waiting) + scheduler_stats.running_lora_adapters[lora_name] = len(stats.running)
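

Illustrative sketch of the new LoRA lifecycle API introduced by this patch. The class and method names (LoRARequestStates, SchedulerStats, request_waiting/request_running/request_finished, update_scheduler_stats) come from vllm/v1/metrics/stats.py above; the request ID and adapter name are made up for the example, and this is not part of the patch itself.

    from vllm.v1.metrics.stats import LoRARequestStates, SchedulerStats

    # Tracking is a no-op unless stat logging is enabled (or lora_name is None).
    lora_states = LoRARequestStates(log_stats=True)

    # QUEUED event: the request is waiting on its adapter.
    lora_states.request_waiting("req-0", "my-lora")

    # SCHEDULED event: the request moves from waiting to running.
    # (A PREEMPTED event maps back to request_waiting().)
    lora_states.request_running("req-0", "my-lora")

    # The tracked request-ID sets are flattened into per-adapter counts on
    # SchedulerStats, which the Prometheus logger's gauge_lora_info consumes.
    scheduler_stats = SchedulerStats()
    lora_states.update_scheduler_stats(scheduler_stats)
    assert scheduler_stats.running_lora_adapters == {"my-lora": 1}
    assert scheduler_stats.waiting_lora_adapters == {"my-lora": 0}

    # Finishing (or aborting) the last request for an adapter drops its entry.
    lora_states.request_finished("req-0", "my-lora")
    assert "my-lora" not in lora_states.requests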