Mirror of https://git.datalinker.icu/vllm-project/vllm.git
[Metrics] Refactor LoRA state tracking (#26801)
Signed-off-by: Mark McLoughlin <markmc@redhat.com>
parent a98cc35c34
commit 6f7de33bed
@@ -15,12 +15,19 @@ from tests.v1.engine.utils import (
 )
 from vllm import PoolingParams
 from vllm.logprobs import PromptLogprobs, SampleLogprobs
+from vllm.lora.request import LoRARequest
 from vllm.outputs import CompletionOutput, RequestOutput
 from vllm.sampling_params import RequestOutputKind, SamplingParams
 from vllm.transformers_utils.tokenizer import AnyTokenizer
-from vllm.v1.engine import EngineCoreRequest
+from vllm.v1.engine import (
+    EngineCoreEvent,
+    EngineCoreEventType,
+    EngineCoreOutputs,
+    EngineCoreRequest,
+    FinishReason,
+)
 from vllm.v1.engine.output_processor import OutputProcessor, RequestOutputCollector
-from vllm.v1.metrics.stats import IterationStats
+from vllm.v1.metrics.stats import IterationStats, SchedulerStats


 def _ref_convert_id_to_token(
@@ -895,6 +902,170 @@ def test_iteration_stats(dummy_test_vectors):
     assert iteration_stats.num_generation_tokens == num_active


+@pytest.mark.parametrize("log_stats", [True, False])
+def test_lora_request_tracking(log_stats: bool, dummy_test_vectors):
+    """Test LoRA request lifecycle tracking through waiting -> running -> finished."""
+    output_processor = OutputProcessor(
+        dummy_test_vectors.tokenizer, log_stats=log_stats
+    )
+    engine_core = MockEngineCore(dummy_test_vectors.generation_tokens)
+    engine_core_timestamp = time.monotonic()
+
+    # Create LoRA requests
+    lora1 = LoRARequest(lora_name="lora-1", lora_int_id=1, lora_path="/path/to/lora1")
+    lora2 = LoRARequest(lora_name="lora-2", lora_int_id=2, lora_path="/path/to/lora2")
+
+    # Create requests with different LoRA adapters:
+    # - request-0: lora-1
+    # - request-1: lora-2
+    # - request-2: None (no LoRA)
+    lora_assignments = [lora1, lora2, None]
+    requests = [
+        EngineCoreRequest(
+            request_id=f"request-{idx}",
+            prompt_token_ids=prompt_tokens,
+            mm_features=None,
+            eos_token_id=None,
+            arrival_time=0,
+            lora_request=lora_assignments[idx],
+            cache_salt=None,
+            data_parallel_rank=None,
+            sampling_params=SamplingParams(),
+            pooling_params=None,
+        )
+        for idx, prompt_tokens in enumerate(dummy_test_vectors.prompt_tokens)
+    ]
+
+    # Add all requests to the OutputProcessor
+    for request in requests:
+        output_processor.add_request(request, None)
+
+    # First iteration: process outputs with QUEUED events
+    outputs = EngineCoreOutputs(
+        outputs=engine_core.get_outputs(), scheduler_stats=SchedulerStats()
+    )
+    for output in outputs.outputs:
+        output.events = [
+            EngineCoreEvent.new_event(EngineCoreEventType.QUEUED, engine_core_timestamp)
+        ]
+
+    iteration_stats = IterationStats() if log_stats else None
+    output_processor.process_outputs(
+        outputs.outputs, engine_core_timestamp, iteration_stats
+    )
+    output_processor.update_scheduler_stats(outputs.scheduler_stats)
+
+    if log_stats:
+        # Verify waiting counts
+        assert outputs.scheduler_stats.waiting_lora_adapters.get("lora-1") == 1
+        assert outputs.scheduler_stats.waiting_lora_adapters.get("lora-2") == 1
+        assert outputs.scheduler_stats.running_lora_adapters.get("lora-1") == 0
+        assert outputs.scheduler_stats.running_lora_adapters.get("lora-2") == 0
+        # Verify internal state
+        assert len(output_processor.lora_states.requests) == 2
+        assert "lora-1" in output_processor.lora_states.requests
+        assert "lora-2" in output_processor.lora_states.requests
+    else:
+        # When log_stats=False, no tracking should occur
+        assert iteration_stats is None
+        assert len(output_processor.lora_states.requests) == 0
+
+    # Second iteration: process outputs with SCHEDULED events
+    outputs = EngineCoreOutputs(
+        outputs=engine_core.get_outputs(), scheduler_stats=SchedulerStats()
+    )
+    for output in outputs.outputs:
+        output.events = [
+            EngineCoreEvent.new_event(
+                EngineCoreEventType.SCHEDULED, engine_core_timestamp
+            )
+        ]
+
+    iteration_stats = IterationStats() if log_stats else None
+    output_processor.process_outputs(
+        outputs.outputs, engine_core_timestamp, iteration_stats
+    )
+    output_processor.update_scheduler_stats(outputs.scheduler_stats)
+
+    if log_stats:
+        # Verify running counts
+        assert outputs.scheduler_stats.waiting_lora_adapters.get("lora-1") == 0
+        assert outputs.scheduler_stats.waiting_lora_adapters.get("lora-2") == 0
+        assert outputs.scheduler_stats.running_lora_adapters.get("lora-1") == 1
+        assert outputs.scheduler_stats.running_lora_adapters.get("lora-2") == 1
+    else:
+        assert iteration_stats is None
+        assert len(output_processor.lora_states.requests) == 0
+
+    # Third iteration: finish request-0 (lora-1)
+    outputs = EngineCoreOutputs(
+        outputs=engine_core.get_outputs(), scheduler_stats=SchedulerStats()
+    )
+    # Find and mark request-0 as finished (it uses lora-1)
+    for output in outputs.outputs:
+        if output.request_id == "request-0":
+            output.finish_reason = FinishReason.LENGTH
+            break
+
+    iteration_stats = IterationStats() if log_stats else None
+    output_processor.process_outputs(
+        outputs.outputs, engine_core_timestamp, iteration_stats
+    )
+    output_processor.update_scheduler_stats(outputs.scheduler_stats)
+
+    if log_stats:
+        # lora-1 should be removed since no requests remain
+        assert "lora-1" not in output_processor.lora_states.requests
+        # lora-2 should still be running
+        assert outputs.scheduler_stats.running_lora_adapters.get("lora-2") == 1
+        assert len(output_processor.lora_states.requests) == 1
+    else:
+        assert len(output_processor.lora_states.requests) == 0
+
+    # Fourth iteration: finish request-1 (lora-2)
+    outputs = EngineCoreOutputs(
+        outputs=engine_core.get_outputs(), scheduler_stats=SchedulerStats()
+    )
+    # Find and mark request-1 as finished (it uses lora-2)
+    for output in outputs.outputs:
+        if output.request_id == "request-1":
+            output.finish_reason = FinishReason.LENGTH
+            break
+
+    iteration_stats = IterationStats() if log_stats else None
+    output_processor.process_outputs(
+        outputs.outputs, engine_core_timestamp, iteration_stats
+    )
+    output_processor.update_scheduler_stats(outputs.scheduler_stats)
+
+    if log_stats:
+        # lora-2 should be removed since no requests remain
+        assert "lora-2" not in output_processor.lora_states.requests
+        assert len(outputs.scheduler_stats.running_lora_adapters) == 0
+        assert len(output_processor.lora_states.requests) == 0
+    else:
+        assert len(output_processor.lora_states.requests) == 0
+
+    # Finish the last request (no LoRA)
+    outputs = EngineCoreOutputs(
+        outputs=engine_core.get_outputs(), scheduler_stats=SchedulerStats()
+    )
+    # Find and mark request-2 as finished (it has no LoRA)
+    for output in outputs.outputs:
+        if output.request_id == "request-2":
+            output.finish_reason = FinishReason.LENGTH
+            break
+
+    iteration_stats = IterationStats() if log_stats else None
+    output_processor.process_outputs(
+        outputs.outputs, engine_core_timestamp, iteration_stats
+    )
+    output_processor.update_scheduler_stats(outputs.scheduler_stats)
+
+    # Verify all requests are finished
+    assert output_processor.get_num_unfinished_requests() == 0
+
+
 @pytest.mark.asyncio
 async def test_request_output_collector():
     NUM_REQS = 3
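For reference, the lifecycle that this new test walks through on the log_stats=True path can be restated as data. The step labels and tuple layout below are illustrative (not part of the commit); the expected per-adapter counts follow from the assertions and the tracker logic shown later in this diff.

# Hypothetical restatement of the expected SchedulerStats contents after each
# iteration of test_lora_request_tracking with log_stats=True.
EXPECTED_ADAPTER_COUNTS = [
    # (step, waiting_lora_adapters, running_lora_adapters)
    ("after QUEUED events", {"lora-1": 1, "lora-2": 1}, {"lora-1": 0, "lora-2": 0}),
    ("after SCHEDULED events", {"lora-1": 0, "lora-2": 0}, {"lora-1": 1, "lora-2": 1}),
    ("request-0 finished", {"lora-2": 0}, {"lora-2": 1}),
    ("request-1 finished", {}, {}),
]

for step, waiting, running in EXPECTED_ADAPTER_COUNTS:
    print(f"{step}: waiting={waiting} running={running}")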
@@ -5,20 +5,4 @@ from vllm.v1.metrics.stats import IterationStats

 def test_iteration_stats_repr():
     iteration_stats = IterationStats()
-    iteration_stats.iteration_timestamp = 0
-    expected_repr = (
-        "IterationStats("
-        "iteration_timestamp=0, "
-        "num_generation_tokens=0, "
-        "num_prompt_tokens=0, "
-        "num_preempted_reqs=0, "
-        "finished_requests=[], "
-        "max_num_generation_tokens_iter=[], "
-        "n_params_iter=[], "
-        "time_to_first_tokens_iter=[], "
-        "inter_token_latencies_iter=[], "
-        "waiting_lora_adapters={}, "
-        "running_lora_adapters={}, "
-        "num_corrupted_reqs=0)"
-    )
-    assert repr(iteration_stats) == expected_repr
+    assert repr(iteration_stats).startswith("IterationStats(")
@@ -508,6 +508,8 @@ class AsyncLLM(EngineClient):
                     processed_outputs.reqs_to_abort
                 )

+                output_processor.update_scheduler_stats(outputs.scheduler_stats)
+
                 # 4) Logging.
                 # TODO(rob): make into a coroutine and launch it in
                 # background thread once Prometheus overhead is non-trivial.
@@ -289,6 +289,7 @@ class LLMEngine:
             engine_core_timestamp=outputs.timestamp,
             iteration_stats=iteration_stats,
         )
+        self.output_processor.update_scheduler_stats(outputs.scheduler_stats)

         # 3) Abort any reqs that finished due to stop strings.
         self.engine_core.abort_requests(processed_outputs.reqs_to_abort)
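Both engine entry points now make the same two calls in the same order: process the engine-core outputs, then let the output processor stamp its per-LoRA counts onto the SchedulerStats that is about to be logged. A minimal, self-contained sketch of that ordering follows; the Stub classes are illustrative stand-ins, not vLLM's real types.

from dataclasses import dataclass, field


@dataclass
class StubSchedulerStats:
    # Mirrors the two per-LoRA dicts added to SchedulerStats in this commit.
    waiting_lora_adapters: dict[str, int] = field(default_factory=dict)
    running_lora_adapters: dict[str, int] = field(default_factory=dict)


class StubOutputProcessor:
    """Illustrative stand-in for the real OutputProcessor."""

    def __init__(self) -> None:
        self._running = {"lora-1": 1}  # pretend one lora-1 request is running

    def process_outputs(self, outputs):
        # The real code detokenizes outputs and updates per-request stats here.
        return outputs

    def update_scheduler_stats(self, scheduler_stats: StubSchedulerStats) -> None:
        # Copy the per-LoRA counts onto the stats object that will be logged.
        scheduler_stats.running_lora_adapters.update(self._running)


processor = StubOutputProcessor()
stats = StubSchedulerStats()
processor.process_outputs([])            # 1) handle this step's engine-core outputs
processor.update_scheduler_stats(stats)  # 2) fill in LoRA counts before logging
assert stats.running_lora_adapters == {"lora-1": 1}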
@@ -22,7 +22,12 @@ from vllm.v1.engine import EngineCoreOutput, EngineCoreRequest, FinishReason
 from vllm.v1.engine.detokenizer import IncrementalDetokenizer
 from vllm.v1.engine.logprobs import LogprobsProcessor
 from vllm.v1.engine.parallel_sampling import ParentRequest
-from vllm.v1.metrics.stats import IterationStats, LoRARequestStates, RequestStateStats
+from vllm.v1.metrics.stats import (
+    IterationStats,
+    LoRARequestStates,
+    RequestStateStats,
+    SchedulerStats,
+)


 class RequestOutputCollector:
@@ -310,7 +315,7 @@ class OutputProcessor:
         self.tokenizer = tokenizer
         self.request_states: dict[str, RequestState] = {}
         self.parent_requests: dict[str, ParentRequest] = {}
-        self.lora_states = LoRARequestStates()
+        self.lora_states = LoRARequestStates(log_stats)
         self.tracer: Tracer | None = None

     def get_num_unfinished_requests(self):
@@ -334,7 +339,7 @@ class OutputProcessor:
         for request_id in request_ids:
             req_state = self.request_states.pop(request_id, None)
             if req_state is not None:
-                self.lora_states.abort_request(req_state)
+                self.lora_states.request_finished(request_id, req_state.lora_name)
                 request_ids_to_abort.append(request_id)
                 # Produce final abort output.
                 if req_state.queue is not None and (
@@ -382,7 +387,6 @@ class OutputProcessor:
             log_stats=self.log_stats,
         )
         self.request_states[request_id] = req_state
-        self.lora_states.add_request(req_state)
         if parent_req:
             self.parent_requests[parent_req.request_id] = parent_req

@@ -484,13 +488,15 @@ class OutputProcessor:
             )
             if self.tracer:
                 self.do_tracing(engine_core_output, req_state, iteration_stats)
-        self.lora_states.update_iteration_stats(iteration_stats)

         return OutputProcessorOutput(
             request_outputs=request_outputs,
             reqs_to_abort=reqs_to_abort,
         )

+    def update_scheduler_stats(self, scheduler_stats: SchedulerStats | None):
+        self.lora_states.update_scheduler_stats(scheduler_stats)
+
     def do_tracing(
         self,
         engine_core_output: EngineCoreOutput,
@@ -564,8 +570,6 @@ class OutputProcessor:
         if iteration_stats is None:
             return

-        lora_stats = self.lora_states.get_stats(req_state)
-
         assert engine_core_timestamp is not None
         assert req_state.stats is not None
         iteration_stats.update_from_output(
@@ -574,7 +578,8 @@ class OutputProcessor:
             req_state.is_prefilling,
             req_state.prompt_len,
             req_state.stats,
-            lora_stats,
+            self.lora_states,
+            req_state.lora_name,
         )

     def _update_stats_from_finished(
@@ -596,7 +601,7 @@ class OutputProcessor:
             max_tokens_param=req_state.max_tokens_param,
             req_stats=req_state.stats,
         )
-        self.lora_states.finish_request(req_state)
+        self.lora_states.request_finished(req_state.request_id, req_state.lora_name)

         ParentRequest.observe_finished_request(
             req_state.parent_req, iteration_stats, req_state.stats.num_generation_tokens
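With this change the abort path (in abort_requests above) and the normal completion path both retire a request's LoRA tracking through the same request_finished(request_id, lora_name) call. A simplified stand-in tracker, illustrating why a single call suffices for both paths (not vLLM's real class):

class SketchTracker:
    def __init__(self) -> None:
        self.requests: dict[str, set[str]] = {}

    def request_running(self, req_id: str, lora_name: str | None) -> None:
        if lora_name is not None:
            self.requests.setdefault(lora_name, set()).add(req_id)

    def request_finished(self, req_id: str, lora_name: str | None) -> None:
        # Works whether the request finished normally or was aborted.
        if lora_name is None:
            return
        reqs = self.requests.get(lora_name)
        if reqs is not None:
            reqs.discard(req_id)
            if not reqs:
                del self.requests[lora_name]


tracker = SketchTracker()
tracker.request_running("request-0", "lora-1")
tracker.request_running("request-1", "lora-1")

tracker.request_finished("request-0", "lora-1")  # finished normally
tracker.request_finished("request-1", "lora-1")  # aborted: same call
assert tracker.requests == {}                    # adapter entry dropped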
@@ -989,6 +989,20 @@ class PrometheusStatLogger(AggregateStatLoggerBase):
                 scheduler_stats.kv_connector_stats, engine_idx
             )

+        if self.gauge_lora_info is not None:
+            running_lora_adapters = ",".join(
+                scheduler_stats.running_lora_adapters.keys()
+            )
+            waiting_lora_adapters = ",".join(
+                scheduler_stats.waiting_lora_adapters.keys()
+            )
+            lora_info_labels = {
+                self.labelname_running_lora_adapters: running_lora_adapters,
+                self.labelname_waiting_lora_adapters: waiting_lora_adapters,
+                self.labelname_max_lora: self.max_lora,
+            }
+            self.gauge_lora_info.labels(**lora_info_labels).set_to_current_time()
+
         if mm_cache_stats is not None:
             self.counter_mm_cache_queries[engine_idx].inc(mm_cache_stats.queries)
             self.counter_mm_cache_hits[engine_idx].inc(mm_cache_stats.hits)
@@ -1055,20 +1069,6 @@ class PrometheusStatLogger(AggregateStatLoggerBase):
                 finished_request.max_tokens_param
             )

-        if self.gauge_lora_info is not None:
-            running_lora_adapters = ",".join(
-                iteration_stats.running_lora_adapters.keys()
-            )
-            waiting_lora_adapters = ",".join(
-                iteration_stats.waiting_lora_adapters.keys()
-            )
-            lora_info_labels = {
-                self.labelname_running_lora_adapters: running_lora_adapters,
-                self.labelname_waiting_lora_adapters: waiting_lora_adapters,
-                self.labelname_max_lora: self.max_lora,
-            }
-            self.gauge_lora_info.labels(**lora_info_labels).set_to_current_time()
-
     def record_sleep_state(self, sleep: int = 0, level: int = 0):
         awake = 1
         discard_all = 0
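The LoRA-info gauge is now fed from scheduler stats rather than from iteration stats, but the label values themselves are unchanged: comma-joined adapter names taken from the per-adapter dicts. A small sketch of the label values this produces (the label keys and max_lora value here are illustrative, not the logger's real label names):

# Example scheduler-stats contents (hypothetical values).
waiting_lora_adapters = {"lora-1": 1}
running_lora_adapters = {"lora-2": 1, "lora-3": 2}

lora_info_labels = {
    "waiting_lora_adapters": ",".join(waiting_lora_adapters.keys()),  # "lora-1"
    "running_lora_adapters": ",".join(running_lora_adapters.keys()),  # "lora-2,lora-3"
    "max_lora": 4,  # illustrative max_loras setting
}
print(lora_info_labels)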
@@ -2,7 +2,7 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project

 import time
-from collections import deque
+from collections import defaultdict, deque
 from dataclasses import dataclass, field
 from typing import TYPE_CHECKING, Any

@@ -11,7 +11,6 @@ from vllm.v1.spec_decode.metrics import SpecDecodingStats

 if TYPE_CHECKING:
     from vllm.v1.engine import EngineCoreEvent, EngineCoreOutput, FinishReason
-    from vllm.v1.engine.output_processor import RequestState


 @dataclass
@@ -170,11 +169,8 @@ class SchedulerStats:
     spec_decoding_stats: SpecDecodingStats | None = None
     kv_connector_stats: dict[str, Any] | None = None
-
-
-@dataclass
-class LoRAStats:
-    waiting_requests: set[str] = field(default_factory=set)
-    running_requests: set[str] = field(default_factory=set)
+    waiting_lora_adapters: dict[str, int] = field(default_factory=dict)
+    running_lora_adapters: dict[str, int] = field(default_factory=dict)


 @dataclass
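The per-adapter counts now live on SchedulerStats as dict fields, with default_factory so that each stats snapshot gets its own dicts. A minimal stand-in showing that field behaviour (the class name here is hypothetical):

from dataclasses import dataclass, field


@dataclass
class SketchSchedulerStats:
    waiting_lora_adapters: dict[str, int] = field(default_factory=dict)
    running_lora_adapters: dict[str, int] = field(default_factory=dict)


a, b = SketchSchedulerStats(), SketchSchedulerStats()
a.waiting_lora_adapters["lora-1"] = 1
assert b.waiting_lora_adapters == {}  # each instance gets independent dicts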
@@ -229,8 +225,6 @@ class IterationStats:
         self.n_params_iter: list[int] = []
         self.time_to_first_tokens_iter: list[float] = []
         self.inter_token_latencies_iter: list[float] = []
-        self.waiting_lora_adapters: dict[str, int] = {}
-        self.running_lora_adapters: dict[str, int] = {}
         self.num_corrupted_reqs: int = 0

     def __repr__(self) -> str:
@@ -248,7 +242,8 @@ class IterationStats:
         is_prefilling: bool,
         prompt_len: int,
         req_stats: RequestStateStats,
-        lora_stats: LoRAStats | None,
+        lora_states: "LoRARequestStates",
+        lora_name: str | None,
     ):
         num_new_generation_tokens = len(output.new_token_ids)

@@ -274,7 +269,12 @@ class IterationStats:
         # Process request-level engine core events
         if output.events is not None:
             self.update_from_events(
-                output.request_id, output.events, is_prefilling, req_stats, lora_stats
+                output.request_id,
+                output.events,
+                is_prefilling,
+                req_stats,
+                lora_states,
+                lora_name,
             )

         # Process the batch-level "new tokens" engine core event
@@ -292,7 +292,8 @@ class IterationStats:
         events: list["EngineCoreEvent"],
         is_prefilling: bool,
         req_stats: RequestStateStats,
-        lora_stats: LoRAStats | None,
+        lora_states: "LoRARequestStates",
+        lora_name: str | None,
     ):
         # Avoid circular dependency
         from vllm.v1.engine import EngineCoreEventType
@@ -300,15 +301,14 @@ class IterationStats:
         for event in events:
             if event.type == EngineCoreEventType.QUEUED:
                 req_stats.queued_ts = event.timestamp
-                if lora_stats is not None:
-                    lora_stats.waiting_requests.add(req_id)
+                lora_states.request_waiting(req_id, lora_name)
             elif event.type == EngineCoreEventType.SCHEDULED:
                 if req_stats.scheduled_ts == 0.0:  # ignore preemptions
                     req_stats.scheduled_ts = event.timestamp
-                LoRARequestStates.scheduled_request(lora_stats, req_id)
+                lora_states.request_running(req_id, lora_name)
             elif event.type == EngineCoreEventType.PREEMPTED:
                 self.num_preempted_reqs += 1
-                LoRARequestStates.preempted_request(lora_stats, req_id)
+                lora_states.request_waiting(req_id, lora_name)

     def update_from_finished_request(
         self,
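The event handling above reduces to a fixed mapping from engine-core event type to the (waiting, running) flags handed to the tracker: QUEUED and PREEMPTED mark the request as waiting, SCHEDULED marks it as running. A small sketch of that mapping, using a stand-in enum rather than vllm.v1.engine.EngineCoreEventType:

from enum import Enum, auto


class SketchEventType(Enum):
    QUEUED = auto()
    SCHEDULED = auto()
    PREEMPTED = auto()


def transition(event_type: SketchEventType) -> tuple[bool, bool]:
    """Return the (waiting, running) flags fed into the per-LoRA tracker."""
    if event_type in (SketchEventType.QUEUED, SketchEventType.PREEMPTED):
        return True, False
    if event_type is SketchEventType.SCHEDULED:
        return False, True
    raise ValueError(event_type)


assert transition(SketchEventType.QUEUED) == (True, False)
assert transition(SketchEventType.SCHEDULED) == (False, True)
assert transition(SketchEventType.PREEMPTED) == (True, False)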
@@ -361,61 +361,60 @@ class IterationStats:
             self.num_corrupted_reqs += 1


-class LoRARequestStates:
-    """Per-LoRA request state stats."""
+class LoRAStats:
+    """Tracks waiting and running request IDs for a single LoRA."""

     def __init__(self):
-        self.lora_name_to_stats: dict[str, LoRAStats] = {}
+        self.waiting: set[str] = set()
+        self.running: set[str] = set()

-    def get_stats(self, req_state: "RequestState") -> LoRAStats | None:
-        if req_state.lora_name is None:
-            return None
-        if req_state.lora_name not in self.lora_name_to_stats:
-            self.lora_name_to_stats[req_state.lora_name] = LoRAStats()
-        return self.lora_name_to_stats[req_state.lora_name]
+    def update(self, req_id: str, waiting: bool, running: bool):
+        assert not (waiting and running)
+        if waiting:
+            self.waiting.add(req_id)
+        else:
+            self.waiting.discard(req_id)

-    def add_request(self, req_state: "RequestState"):
-        if (lora_stats := self.get_stats(req_state)) is not None:
-            lora_stats.waiting_requests.add(req_state.request_id)
+        if running:
+            self.running.add(req_id)
+        else:
+            self.running.discard(req_id)

-    def finish_request(self, req_state: "RequestState"):
-        if req_state.lora_name is None:
+    @property
+    def empty(self) -> bool:
+        return not (self.waiting or self.running)
+
+
+class LoRARequestStates:
+    """A per-LoRA count of running and waiting requests."""
+
+    def __init__(self, log_stats: bool = False):
+        self.log_stats = log_stats
+        self.requests: defaultdict[str, LoRAStats] = defaultdict(LoRAStats)
+
+    def _request_update(
+        self, req_id: str, lora_name: str | None, waiting: bool, running: bool
+    ):
+        if not self.log_stats or lora_name is None:
             return
-        lora_stats = self.lora_name_to_stats[req_state.lora_name]
-        lora_stats.running_requests.remove(req_state.request_id)
-
-    def abort_request(self, req_state: "RequestState"):
-        if req_state.lora_name is None:
-            return
-        lora_stats = self.lora_name_to_stats[req_state.lora_name]
-        lora_stats.waiting_requests.discard(req_state.request_id)
-        lora_stats.running_requests.discard(req_state.request_id)
+        lora_stats = self.requests[lora_name]
+        lora_stats.update(req_id, waiting, running)
+        if lora_stats.empty:
+            del self.requests[lora_name]

-    # Break the pattern for this lifecycle methods so we can
-    # call this from IterationStats.update_from_events()
-    @staticmethod
-    def scheduled_request(lora_stats: LoRAStats | None, request_id: str):
-        if lora_stats is None:
-            return
-        lora_stats.waiting_requests.remove(request_id)
-        lora_stats.running_requests.add(request_id)
+    def request_waiting(self, req_id: str, lora_name: str | None):
+        self._request_update(req_id, lora_name, waiting=True, running=False)

-    @staticmethod
-    def preempted_request(lora_stats: LoRAStats | None, request_id: str):
-        if lora_stats is None:
-            return
-        lora_stats.running_requests.remove(request_id)
-        lora_stats.waiting_requests.add(request_id)
+    def request_running(self, req_id: str, lora_name: str | None):
+        self._request_update(req_id, lora_name, waiting=False, running=True)

-    def update_iteration_stats(self, iteration_stats: IterationStats | None):
-        if iteration_stats is None:
+    def request_finished(self, req_id: str, lora_name: str | None):
+        self._request_update(req_id, lora_name, waiting=False, running=False)
+
+    def update_scheduler_stats(self, scheduler_stats: SchedulerStats | None):
+        if not self.log_stats or scheduler_stats is None:
             return
-        for lora_name, stats in self.lora_name_to_stats.items():
-            if stats.waiting_requests:
-                iteration_stats.waiting_lora_adapters[lora_name] = len(
-                    stats.waiting_requests
-                )
-            if stats.running_requests:
-                iteration_stats.running_lora_adapters[lora_name] = len(
-                    stats.running_requests
-                )
+        for lora_name, stats in self.requests.items():
+            scheduler_stats.waiting_lora_adapters[lora_name] = len(stats.waiting)
+            scheduler_stats.running_lora_adapters[lora_name] = len(stats.running)
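A short usage sketch of the refactored tracker defined above; the import path assumes the post-commit layout of vllm.v1.metrics.stats, so treat it as illustrative rather than guaranteed to run against every checkout.

from vllm.v1.metrics.stats import LoRARequestStates, SchedulerStats

states = LoRARequestStates(log_stats=True)
states.request_waiting("request-0", "lora-1")   # QUEUED
states.request_running("request-0", "lora-1")   # SCHEDULED

stats = SchedulerStats()
states.update_scheduler_stats(stats)
assert stats.running_lora_adapters == {"lora-1": 1}
assert stats.waiting_lora_adapters == {"lora-1": 0}

# Finishing the last request for an adapter drops its entry entirely.
states.request_finished("request-0", "lora-1")
assert "lora-1" not in states.requests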