[Metrics] Refactor LoRA state tracking (#26801)

Signed-off-by: Mark McLoughlin <markmc@redhat.com>
This commit is contained in:
Mark McLoughlin 2025-11-10 08:34:36 +00:00 committed by GitHub
parent a98cc35c34
commit 6f7de33bed
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 267 additions and 105 deletions

View File

@ -15,12 +15,19 @@ from tests.v1.engine.utils import (
) )
from vllm import PoolingParams from vllm import PoolingParams
from vllm.logprobs import PromptLogprobs, SampleLogprobs from vllm.logprobs import PromptLogprobs, SampleLogprobs
from vllm.lora.request import LoRARequest
from vllm.outputs import CompletionOutput, RequestOutput from vllm.outputs import CompletionOutput, RequestOutput
from vllm.sampling_params import RequestOutputKind, SamplingParams from vllm.sampling_params import RequestOutputKind, SamplingParams
from vllm.transformers_utils.tokenizer import AnyTokenizer from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.v1.engine import EngineCoreRequest from vllm.v1.engine import (
EngineCoreEvent,
EngineCoreEventType,
EngineCoreOutputs,
EngineCoreRequest,
FinishReason,
)
from vllm.v1.engine.output_processor import OutputProcessor, RequestOutputCollector from vllm.v1.engine.output_processor import OutputProcessor, RequestOutputCollector
from vllm.v1.metrics.stats import IterationStats from vllm.v1.metrics.stats import IterationStats, SchedulerStats
def _ref_convert_id_to_token( def _ref_convert_id_to_token(
@ -895,6 +902,170 @@ def test_iteration_stats(dummy_test_vectors):
assert iteration_stats.num_generation_tokens == num_active assert iteration_stats.num_generation_tokens == num_active
@pytest.mark.parametrize("log_stats", [True, False])
def test_lora_request_tracking(log_stats: bool, dummy_test_vectors):
"""Test LoRA request lifecycle tracking through waiting -> running -> finished."""
output_processor = OutputProcessor(
dummy_test_vectors.tokenizer, log_stats=log_stats
)
engine_core = MockEngineCore(dummy_test_vectors.generation_tokens)
engine_core_timestamp = time.monotonic()
# Create LoRA requests
lora1 = LoRARequest(lora_name="lora-1", lora_int_id=1, lora_path="/path/to/lora1")
lora2 = LoRARequest(lora_name="lora-2", lora_int_id=2, lora_path="/path/to/lora2")
# Create requests with different LoRA adapters:
# - request-0: lora-1
# - request-1: lora-2
# - request-2: None (no LoRA)
lora_assignments = [lora1, lora2, None]
requests = [
EngineCoreRequest(
request_id=f"request-{idx}",
prompt_token_ids=prompt_tokens,
mm_features=None,
eos_token_id=None,
arrival_time=0,
lora_request=lora_assignments[idx],
cache_salt=None,
data_parallel_rank=None,
sampling_params=SamplingParams(),
pooling_params=None,
)
for idx, prompt_tokens in enumerate(dummy_test_vectors.prompt_tokens)
]
# Add all requests to the OutputProcessor
for request in requests:
output_processor.add_request(request, None)
# First iteration: process outputs with QUEUED events
outputs = EngineCoreOutputs(
outputs=engine_core.get_outputs(), scheduler_stats=SchedulerStats()
)
for output in outputs.outputs:
output.events = [
EngineCoreEvent.new_event(EngineCoreEventType.QUEUED, engine_core_timestamp)
]
iteration_stats = IterationStats() if log_stats else None
output_processor.process_outputs(
outputs.outputs, engine_core_timestamp, iteration_stats
)
output_processor.update_scheduler_stats(outputs.scheduler_stats)
if log_stats:
# Verify waiting counts
assert outputs.scheduler_stats.waiting_lora_adapters.get("lora-1") == 1
assert outputs.scheduler_stats.waiting_lora_adapters.get("lora-2") == 1
assert outputs.scheduler_stats.running_lora_adapters.get("lora-1") == 0
assert outputs.scheduler_stats.running_lora_adapters.get("lora-2") == 0
# Verify internal state
assert len(output_processor.lora_states.requests) == 2
assert "lora-1" in output_processor.lora_states.requests
assert "lora-2" in output_processor.lora_states.requests
else:
# When log_stats=False, no tracking should occur
assert iteration_stats is None
assert len(output_processor.lora_states.requests) == 0
# Second iteration: process outputs with SCHEDULED events
outputs = EngineCoreOutputs(
outputs=engine_core.get_outputs(), scheduler_stats=SchedulerStats()
)
for output in outputs.outputs:
output.events = [
EngineCoreEvent.new_event(
EngineCoreEventType.SCHEDULED, engine_core_timestamp
)
]
iteration_stats = IterationStats() if log_stats else None
output_processor.process_outputs(
outputs.outputs, engine_core_timestamp, iteration_stats
)
output_processor.update_scheduler_stats(outputs.scheduler_stats)
if log_stats:
# Verify running counts
assert outputs.scheduler_stats.waiting_lora_adapters.get("lora-1") == 0
assert outputs.scheduler_stats.waiting_lora_adapters.get("lora-2") == 0
assert outputs.scheduler_stats.running_lora_adapters.get("lora-1") == 1
assert outputs.scheduler_stats.running_lora_adapters.get("lora-2") == 1
else:
assert iteration_stats is None
assert len(output_processor.lora_states.requests) == 0
# Third iteration: finish request-0 (lora-1)
outputs = EngineCoreOutputs(
outputs=engine_core.get_outputs(), scheduler_stats=SchedulerStats()
)
# Find and mark request-0 as finished (it uses lora-1)
for output in outputs.outputs:
if output.request_id == "request-0":
output.finish_reason = FinishReason.LENGTH
break
iteration_stats = IterationStats() if log_stats else None
output_processor.process_outputs(
outputs.outputs, engine_core_timestamp, iteration_stats
)
output_processor.update_scheduler_stats(outputs.scheduler_stats)
if log_stats:
# lora-1 should be removed since no requests remain
assert "lora-1" not in output_processor.lora_states.requests
# lora-2 should still be running
assert outputs.scheduler_stats.running_lora_adapters.get("lora-2") == 1
assert len(output_processor.lora_states.requests) == 1
else:
assert len(output_processor.lora_states.requests) == 0
# Fourth iteration: finish request-1 (lora-2)
outputs = EngineCoreOutputs(
outputs=engine_core.get_outputs(), scheduler_stats=SchedulerStats()
)
# Find and mark request-1 as finished (it uses lora-2)
for output in outputs.outputs:
if output.request_id == "request-1":
output.finish_reason = FinishReason.LENGTH
break
iteration_stats = IterationStats() if log_stats else None
output_processor.process_outputs(
outputs.outputs, engine_core_timestamp, iteration_stats
)
output_processor.update_scheduler_stats(outputs.scheduler_stats)
if log_stats:
# lora-2 should be removed since no requests remain
assert "lora-2" not in output_processor.lora_states.requests
assert len(outputs.scheduler_stats.running_lora_adapters) == 0
assert len(output_processor.lora_states.requests) == 0
else:
assert len(output_processor.lora_states.requests) == 0
# Finish the last request (no LoRA)
outputs = EngineCoreOutputs(
outputs=engine_core.get_outputs(), scheduler_stats=SchedulerStats()
)
# Find and mark request-2 as finished (it has no LoRA)
for output in outputs.outputs:
if output.request_id == "request-2":
output.finish_reason = FinishReason.LENGTH
break
iteration_stats = IterationStats() if log_stats else None
output_processor.process_outputs(
outputs.outputs, engine_core_timestamp, iteration_stats
)
output_processor.update_scheduler_stats(outputs.scheduler_stats)
# Verify all requests are finished
assert output_processor.get_num_unfinished_requests() == 0
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_request_output_collector(): async def test_request_output_collector():
NUM_REQS = 3 NUM_REQS = 3

View File

@ -5,20 +5,4 @@ from vllm.v1.metrics.stats import IterationStats
def test_iteration_stats_repr(): def test_iteration_stats_repr():
iteration_stats = IterationStats() iteration_stats = IterationStats()
iteration_stats.iteration_timestamp = 0 assert repr(iteration_stats).startswith("IterationStats(")
expected_repr = (
"IterationStats("
"iteration_timestamp=0, "
"num_generation_tokens=0, "
"num_prompt_tokens=0, "
"num_preempted_reqs=0, "
"finished_requests=[], "
"max_num_generation_tokens_iter=[], "
"n_params_iter=[], "
"time_to_first_tokens_iter=[], "
"inter_token_latencies_iter=[], "
"waiting_lora_adapters={}, "
"running_lora_adapters={}, "
"num_corrupted_reqs=0)"
)
assert repr(iteration_stats) == expected_repr

View File

@ -508,6 +508,8 @@ class AsyncLLM(EngineClient):
processed_outputs.reqs_to_abort processed_outputs.reqs_to_abort
) )
output_processor.update_scheduler_stats(outputs.scheduler_stats)
# 4) Logging. # 4) Logging.
# TODO(rob): make into a coroutine and launch it in # TODO(rob): make into a coroutine and launch it in
# background thread once Prometheus overhead is non-trivial. # background thread once Prometheus overhead is non-trivial.

View File

@ -289,6 +289,7 @@ class LLMEngine:
engine_core_timestamp=outputs.timestamp, engine_core_timestamp=outputs.timestamp,
iteration_stats=iteration_stats, iteration_stats=iteration_stats,
) )
self.output_processor.update_scheduler_stats(outputs.scheduler_stats)
# 3) Abort any reqs that finished due to stop strings. # 3) Abort any reqs that finished due to stop strings.
self.engine_core.abort_requests(processed_outputs.reqs_to_abort) self.engine_core.abort_requests(processed_outputs.reqs_to_abort)

View File

@ -22,7 +22,12 @@ from vllm.v1.engine import EngineCoreOutput, EngineCoreRequest, FinishReason
from vllm.v1.engine.detokenizer import IncrementalDetokenizer from vllm.v1.engine.detokenizer import IncrementalDetokenizer
from vllm.v1.engine.logprobs import LogprobsProcessor from vllm.v1.engine.logprobs import LogprobsProcessor
from vllm.v1.engine.parallel_sampling import ParentRequest from vllm.v1.engine.parallel_sampling import ParentRequest
from vllm.v1.metrics.stats import IterationStats, LoRARequestStates, RequestStateStats from vllm.v1.metrics.stats import (
IterationStats,
LoRARequestStates,
RequestStateStats,
SchedulerStats,
)
class RequestOutputCollector: class RequestOutputCollector:
@ -310,7 +315,7 @@ class OutputProcessor:
self.tokenizer = tokenizer self.tokenizer = tokenizer
self.request_states: dict[str, RequestState] = {} self.request_states: dict[str, RequestState] = {}
self.parent_requests: dict[str, ParentRequest] = {} self.parent_requests: dict[str, ParentRequest] = {}
self.lora_states = LoRARequestStates() self.lora_states = LoRARequestStates(log_stats)
self.tracer: Tracer | None = None self.tracer: Tracer | None = None
def get_num_unfinished_requests(self): def get_num_unfinished_requests(self):
@ -334,7 +339,7 @@ class OutputProcessor:
for request_id in request_ids: for request_id in request_ids:
req_state = self.request_states.pop(request_id, None) req_state = self.request_states.pop(request_id, None)
if req_state is not None: if req_state is not None:
self.lora_states.abort_request(req_state) self.lora_states.request_finished(request_id, req_state.lora_name)
request_ids_to_abort.append(request_id) request_ids_to_abort.append(request_id)
# Produce final abort output. # Produce final abort output.
if req_state.queue is not None and ( if req_state.queue is not None and (
@ -382,7 +387,6 @@ class OutputProcessor:
log_stats=self.log_stats, log_stats=self.log_stats,
) )
self.request_states[request_id] = req_state self.request_states[request_id] = req_state
self.lora_states.add_request(req_state)
if parent_req: if parent_req:
self.parent_requests[parent_req.request_id] = parent_req self.parent_requests[parent_req.request_id] = parent_req
@ -484,13 +488,15 @@ class OutputProcessor:
) )
if self.tracer: if self.tracer:
self.do_tracing(engine_core_output, req_state, iteration_stats) self.do_tracing(engine_core_output, req_state, iteration_stats)
self.lora_states.update_iteration_stats(iteration_stats)
return OutputProcessorOutput( return OutputProcessorOutput(
request_outputs=request_outputs, request_outputs=request_outputs,
reqs_to_abort=reqs_to_abort, reqs_to_abort=reqs_to_abort,
) )
def update_scheduler_stats(self, scheduler_stats: SchedulerStats | None):
self.lora_states.update_scheduler_stats(scheduler_stats)
def do_tracing( def do_tracing(
self, self,
engine_core_output: EngineCoreOutput, engine_core_output: EngineCoreOutput,
@ -564,8 +570,6 @@ class OutputProcessor:
if iteration_stats is None: if iteration_stats is None:
return return
lora_stats = self.lora_states.get_stats(req_state)
assert engine_core_timestamp is not None assert engine_core_timestamp is not None
assert req_state.stats is not None assert req_state.stats is not None
iteration_stats.update_from_output( iteration_stats.update_from_output(
@ -574,7 +578,8 @@ class OutputProcessor:
req_state.is_prefilling, req_state.is_prefilling,
req_state.prompt_len, req_state.prompt_len,
req_state.stats, req_state.stats,
lora_stats, self.lora_states,
req_state.lora_name,
) )
def _update_stats_from_finished( def _update_stats_from_finished(
@ -596,7 +601,7 @@ class OutputProcessor:
max_tokens_param=req_state.max_tokens_param, max_tokens_param=req_state.max_tokens_param,
req_stats=req_state.stats, req_stats=req_state.stats,
) )
self.lora_states.finish_request(req_state) self.lora_states.request_finished(req_state.request_id, req_state.lora_name)
ParentRequest.observe_finished_request( ParentRequest.observe_finished_request(
req_state.parent_req, iteration_stats, req_state.stats.num_generation_tokens req_state.parent_req, iteration_stats, req_state.stats.num_generation_tokens

View File

@ -989,6 +989,20 @@ class PrometheusStatLogger(AggregateStatLoggerBase):
scheduler_stats.kv_connector_stats, engine_idx scheduler_stats.kv_connector_stats, engine_idx
) )
if self.gauge_lora_info is not None:
running_lora_adapters = ",".join(
scheduler_stats.running_lora_adapters.keys()
)
waiting_lora_adapters = ",".join(
scheduler_stats.waiting_lora_adapters.keys()
)
lora_info_labels = {
self.labelname_running_lora_adapters: running_lora_adapters,
self.labelname_waiting_lora_adapters: waiting_lora_adapters,
self.labelname_max_lora: self.max_lora,
}
self.gauge_lora_info.labels(**lora_info_labels).set_to_current_time()
if mm_cache_stats is not None: if mm_cache_stats is not None:
self.counter_mm_cache_queries[engine_idx].inc(mm_cache_stats.queries) self.counter_mm_cache_queries[engine_idx].inc(mm_cache_stats.queries)
self.counter_mm_cache_hits[engine_idx].inc(mm_cache_stats.hits) self.counter_mm_cache_hits[engine_idx].inc(mm_cache_stats.hits)
@ -1055,20 +1069,6 @@ class PrometheusStatLogger(AggregateStatLoggerBase):
finished_request.max_tokens_param finished_request.max_tokens_param
) )
if self.gauge_lora_info is not None:
running_lora_adapters = ",".join(
iteration_stats.running_lora_adapters.keys()
)
waiting_lora_adapters = ",".join(
iteration_stats.waiting_lora_adapters.keys()
)
lora_info_labels = {
self.labelname_running_lora_adapters: running_lora_adapters,
self.labelname_waiting_lora_adapters: waiting_lora_adapters,
self.labelname_max_lora: self.max_lora,
}
self.gauge_lora_info.labels(**lora_info_labels).set_to_current_time()
def record_sleep_state(self, sleep: int = 0, level: int = 0): def record_sleep_state(self, sleep: int = 0, level: int = 0):
awake = 1 awake = 1
discard_all = 0 discard_all = 0

View File

@ -2,7 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import time import time
from collections import deque from collections import defaultdict, deque
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any from typing import TYPE_CHECKING, Any
@ -11,7 +11,6 @@ from vllm.v1.spec_decode.metrics import SpecDecodingStats
if TYPE_CHECKING: if TYPE_CHECKING:
from vllm.v1.engine import EngineCoreEvent, EngineCoreOutput, FinishReason from vllm.v1.engine import EngineCoreEvent, EngineCoreOutput, FinishReason
from vllm.v1.engine.output_processor import RequestState
@dataclass @dataclass
@ -170,11 +169,8 @@ class SchedulerStats:
spec_decoding_stats: SpecDecodingStats | None = None spec_decoding_stats: SpecDecodingStats | None = None
kv_connector_stats: dict[str, Any] | None = None kv_connector_stats: dict[str, Any] | None = None
waiting_lora_adapters: dict[str, int] = field(default_factory=dict)
@dataclass running_lora_adapters: dict[str, int] = field(default_factory=dict)
class LoRAStats:
waiting_requests: set[str] = field(default_factory=set)
running_requests: set[str] = field(default_factory=set)
@dataclass @dataclass
@ -229,8 +225,6 @@ class IterationStats:
self.n_params_iter: list[int] = [] self.n_params_iter: list[int] = []
self.time_to_first_tokens_iter: list[float] = [] self.time_to_first_tokens_iter: list[float] = []
self.inter_token_latencies_iter: list[float] = [] self.inter_token_latencies_iter: list[float] = []
self.waiting_lora_adapters: dict[str, int] = {}
self.running_lora_adapters: dict[str, int] = {}
self.num_corrupted_reqs: int = 0 self.num_corrupted_reqs: int = 0
def __repr__(self) -> str: def __repr__(self) -> str:
@ -248,7 +242,8 @@ class IterationStats:
is_prefilling: bool, is_prefilling: bool,
prompt_len: int, prompt_len: int,
req_stats: RequestStateStats, req_stats: RequestStateStats,
lora_stats: LoRAStats | None, lora_states: "LoRARequestStates",
lora_name: str | None,
): ):
num_new_generation_tokens = len(output.new_token_ids) num_new_generation_tokens = len(output.new_token_ids)
@ -274,7 +269,12 @@ class IterationStats:
# Process request-level engine core events # Process request-level engine core events
if output.events is not None: if output.events is not None:
self.update_from_events( self.update_from_events(
output.request_id, output.events, is_prefilling, req_stats, lora_stats output.request_id,
output.events,
is_prefilling,
req_stats,
lora_states,
lora_name,
) )
# Process the batch-level "new tokens" engine core event # Process the batch-level "new tokens" engine core event
@ -292,7 +292,8 @@ class IterationStats:
events: list["EngineCoreEvent"], events: list["EngineCoreEvent"],
is_prefilling: bool, is_prefilling: bool,
req_stats: RequestStateStats, req_stats: RequestStateStats,
lora_stats: LoRAStats | None, lora_states: "LoRARequestStates",
lora_name: str | None,
): ):
# Avoid circular dependency # Avoid circular dependency
from vllm.v1.engine import EngineCoreEventType from vllm.v1.engine import EngineCoreEventType
@ -300,15 +301,14 @@ class IterationStats:
for event in events: for event in events:
if event.type == EngineCoreEventType.QUEUED: if event.type == EngineCoreEventType.QUEUED:
req_stats.queued_ts = event.timestamp req_stats.queued_ts = event.timestamp
if lora_stats is not None: lora_states.request_waiting(req_id, lora_name)
lora_stats.waiting_requests.add(req_id)
elif event.type == EngineCoreEventType.SCHEDULED: elif event.type == EngineCoreEventType.SCHEDULED:
if req_stats.scheduled_ts == 0.0: # ignore preemptions if req_stats.scheduled_ts == 0.0: # ignore preemptions
req_stats.scheduled_ts = event.timestamp req_stats.scheduled_ts = event.timestamp
LoRARequestStates.scheduled_request(lora_stats, req_id) lora_states.request_running(req_id, lora_name)
elif event.type == EngineCoreEventType.PREEMPTED: elif event.type == EngineCoreEventType.PREEMPTED:
self.num_preempted_reqs += 1 self.num_preempted_reqs += 1
LoRARequestStates.preempted_request(lora_stats, req_id) lora_states.request_waiting(req_id, lora_name)
def update_from_finished_request( def update_from_finished_request(
self, self,
@ -361,61 +361,60 @@ class IterationStats:
self.num_corrupted_reqs += 1 self.num_corrupted_reqs += 1
class LoRARequestStates: class LoRAStats:
"""Per-LoRA request state stats.""" """Tracks waiting and running request IDs for a single LoRA."""
def __init__(self): def __init__(self):
self.lora_name_to_stats: dict[str, LoRAStats] = {} self.waiting: set[str] = set()
self.running: set[str] = set()
def get_stats(self, req_state: "RequestState") -> LoRAStats | None: def update(self, req_id: str, waiting: bool, running: bool):
if req_state.lora_name is None: assert not (waiting and running)
return None if waiting:
if req_state.lora_name not in self.lora_name_to_stats: self.waiting.add(req_id)
self.lora_name_to_stats[req_state.lora_name] = LoRAStats() else:
return self.lora_name_to_stats[req_state.lora_name] self.waiting.discard(req_id)
def add_request(self, req_state: "RequestState"): if running:
if (lora_stats := self.get_stats(req_state)) is not None: self.running.add(req_id)
lora_stats.waiting_requests.add(req_state.request_id) else:
self.running.discard(req_id)
def finish_request(self, req_state: "RequestState"): @property
if req_state.lora_name is None: def empty(self) -> bool:
return not (self.waiting or self.running)
class LoRARequestStates:
"""A per-LoRA count of running and waiting requests."""
def __init__(self, log_stats: bool = False):
self.log_stats = log_stats
self.requests: defaultdict[str, LoRAStats] = defaultdict(LoRAStats)
def _request_update(
self, req_id: str, lora_name: str | None, waiting: bool, running: bool
):
if not self.log_stats or lora_name is None:
return return
lora_stats = self.lora_name_to_stats[req_state.lora_name]
lora_stats.running_requests.remove(req_state.request_id)
def abort_request(self, req_state: "RequestState"): lora_stats = self.requests[lora_name]
if req_state.lora_name is None: lora_stats.update(req_id, waiting, running)
return if lora_stats.empty:
lora_stats = self.lora_name_to_stats[req_state.lora_name] del self.requests[lora_name]
lora_stats.waiting_requests.discard(req_state.request_id)
lora_stats.running_requests.discard(req_state.request_id)
# Break the pattern for this lifecycle methods so we can def request_waiting(self, req_id: str, lora_name: str | None):
# call this from IterationStats.update_from_events() self._request_update(req_id, lora_name, waiting=True, running=False)
@staticmethod
def scheduled_request(lora_stats: LoRAStats | None, request_id: str):
if lora_stats is None:
return
lora_stats.waiting_requests.remove(request_id)
lora_stats.running_requests.add(request_id)
@staticmethod def request_running(self, req_id: str, lora_name: str | None):
def preempted_request(lora_stats: LoRAStats | None, request_id: str): self._request_update(req_id, lora_name, waiting=False, running=True)
if lora_stats is None:
return
lora_stats.running_requests.remove(request_id)
lora_stats.waiting_requests.add(request_id)
def update_iteration_stats(self, iteration_stats: IterationStats | None): def request_finished(self, req_id: str, lora_name: str | None):
if iteration_stats is None: self._request_update(req_id, lora_name, waiting=False, running=False)
def update_scheduler_stats(self, scheduler_stats: SchedulerStats | None):
if not self.log_stats or scheduler_stats is None:
return return
for lora_name, stats in self.lora_name_to_stats.items(): for lora_name, stats in self.requests.items():
if stats.waiting_requests: scheduler_stats.waiting_lora_adapters[lora_name] = len(stats.waiting)
iteration_stats.waiting_lora_adapters[lora_name] = len( scheduler_stats.running_lora_adapters[lora_name] = len(stats.running)
stats.waiting_requests
)
if stats.running_requests:
iteration_stats.running_lora_adapters[lora_name] = len(
stats.running_requests
)