mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-09 23:34:27 +08:00
[V1][Core][1/n] Logging and Metrics (#11962)
Signed-off-by: rshaw@neuralmagic.com <rshaw@neuralmagic.com>
This commit is contained in:
parent
263a870ee1
commit
9597a095f2
@ -80,7 +80,7 @@ def test_engine_core(monkeypatch):
|
||||
assert len(engine_core.scheduler.running) == 4
|
||||
|
||||
# Loop through until they are all done.
|
||||
while len(engine_core.step()) > 0:
|
||||
while len(engine_core.step().outputs) > 0:
|
||||
pass
|
||||
|
||||
assert len(engine_core.scheduler.waiting) == 0
|
||||
@ -170,7 +170,7 @@ def test_engine_core_advanced_sampling(monkeypatch):
|
||||
assert len(engine_core.scheduler.waiting) == 1
|
||||
assert len(engine_core.scheduler.running) == 0
|
||||
# Loop through until they are all done.
|
||||
while len(engine_core.step()) > 0:
|
||||
while len(engine_core.step().outputs) > 0:
|
||||
pass
|
||||
|
||||
assert len(engine_core.scheduler.waiting) == 0
|
||||
|
||||
@ -43,7 +43,7 @@ def make_request(params: SamplingParams) -> EngineCoreRequest:
|
||||
def loop_until_done(client: EngineCoreClient, outputs: Dict):
|
||||
|
||||
while True:
|
||||
engine_core_outputs = client.get_output()
|
||||
engine_core_outputs = client.get_output().outputs
|
||||
|
||||
if len(engine_core_outputs) == 0:
|
||||
break
|
||||
@ -61,7 +61,7 @@ def loop_until_done(client: EngineCoreClient, outputs: Dict):
|
||||
async def loop_until_done_async(client: EngineCoreClient, outputs: Dict):
|
||||
|
||||
while True:
|
||||
engine_core_outputs = await client.get_output_async()
|
||||
engine_core_outputs = await client.get_output_async().outputs
|
||||
|
||||
if len(engine_core_outputs) == 0:
|
||||
break
|
||||
|
||||
@ -8,7 +8,8 @@ from vllm.logger import init_logger
|
||||
from vllm.sampling_params import SamplingParams
|
||||
from vllm.v1.core.encoder_cache_manager import EncoderCacheManager
|
||||
from vllm.v1.core.kv_cache_manager import KVCacheManager
|
||||
from vllm.v1.engine import EngineCoreOutput
|
||||
from vllm.v1.engine import EngineCoreOutput, EngineCoreOutputs
|
||||
from vllm.v1.metrics.stats import SchedulerStats
|
||||
from vllm.v1.outputs import ModelRunnerOutput
|
||||
from vllm.v1.request import Request, RequestStatus
|
||||
|
||||
@ -394,12 +395,12 @@ class Scheduler:
|
||||
self,
|
||||
scheduler_output: "SchedulerOutput",
|
||||
model_runner_output: "ModelRunnerOutput",
|
||||
) -> List[EngineCoreOutput]:
|
||||
) -> EngineCoreOutputs:
|
||||
# NOTE(woosuk): This method doesn't consider speculative decoding.
|
||||
sampled_token_ids = model_runner_output.sampled_token_ids
|
||||
num_scheduled_tokens = scheduler_output.num_scheduled_tokens
|
||||
new_running: List[Request] = []
|
||||
engine_core_outputs: List[EngineCoreOutput] = []
|
||||
outputs: List[EngineCoreOutput] = []
|
||||
for request in self.running:
|
||||
req_id = request.request_id
|
||||
request.num_computed_tokens += num_scheduled_tokens[req_id]
|
||||
@ -438,7 +439,7 @@ class Scheduler:
|
||||
finished=request.is_finished(),
|
||||
finish_reason=request.get_finished_reason(),
|
||||
stop_reason=request.stop_reason)
|
||||
engine_core_outputs.append(output)
|
||||
outputs.append(output)
|
||||
|
||||
# Breakout of the loop.
|
||||
if stopped:
|
||||
@ -446,7 +447,10 @@ class Scheduler:
|
||||
|
||||
new_running.append(request)
|
||||
self.running = new_running
|
||||
return engine_core_outputs
|
||||
return EngineCoreOutputs(
|
||||
outputs=outputs,
|
||||
scheduler_stats=self.make_stats(),
|
||||
)
|
||||
|
||||
def _check_stop(self, request: Request) -> bool:
|
||||
if (request.num_tokens >= self.max_model_len
|
||||
@ -515,6 +519,12 @@ class Scheduler:
|
||||
def has_unfinished_requests(self) -> bool:
|
||||
return self.get_num_unfinished_requests() > 0
|
||||
|
||||
def make_stats(self) -> SchedulerStats:
|
||||
return SchedulerStats(
|
||||
num_running_reqs=len(self.running),
|
||||
num_waiting_reqs=len(self.waiting),
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class NewRequestData:
|
||||
|
||||
@ -4,6 +4,8 @@ from typing import TYPE_CHECKING, List, Optional, Union
|
||||
|
||||
import msgspec
|
||||
|
||||
from vllm.v1.metrics.stats import SchedulerStats
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.lora.request import LoRARequest
|
||||
from vllm.multimodal import MultiModalKwargs
|
||||
@ -56,6 +58,7 @@ class EngineCoreOutputs(
|
||||
|
||||
# [num_reqs]
|
||||
outputs: List[EngineCoreOutput]
|
||||
scheduler_stats: SchedulerStats
|
||||
|
||||
|
||||
@dataclass
|
||||
|
||||
@ -4,7 +4,6 @@ from typing import AsyncGenerator, Dict, List, Mapping, Optional, Type, Union
|
||||
|
||||
from vllm.config import ModelConfig, VllmConfig
|
||||
from vllm.engine.arg_utils import AsyncEngineArgs
|
||||
from vllm.engine.metrics_types import StatLoggerBase
|
||||
from vllm.engine.protocol import EngineClient
|
||||
from vllm.inputs import INPUT_REGISTRY, InputRegistry, PromptType
|
||||
from vllm.inputs.preprocess import InputPreprocessor
|
||||
@ -22,6 +21,8 @@ from vllm.v1.engine.core_client import EngineCoreClient
|
||||
from vllm.v1.engine.detokenizer import Detokenizer
|
||||
from vllm.v1.engine.processor import Processor
|
||||
from vllm.v1.executor.abstract import Executor
|
||||
from vllm.v1.metrics.loggers import LoggingStatLogger, StatLoggerBase
|
||||
from vllm.v1.metrics.stats import SchedulerStats
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
@ -34,7 +35,6 @@ class AsyncLLM(EngineClient):
|
||||
executor_class: Type[Executor],
|
||||
log_stats: bool,
|
||||
usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
|
||||
stat_loggers: Optional[Dict[str, StatLoggerBase]] = None,
|
||||
input_registry: InputRegistry = INPUT_REGISTRY,
|
||||
use_cached_outputs: bool = False,
|
||||
log_requests: bool = True,
|
||||
@ -45,7 +45,10 @@ class AsyncLLM(EngineClient):
|
||||
|
||||
self.log_requests = log_requests
|
||||
self.log_stats = log_stats
|
||||
self.stat_loggers = stat_loggers
|
||||
self.stat_loggers: List[StatLoggerBase] = [
|
||||
LoggingStatLogger(),
|
||||
# TODO(rob): PrometheusStatLogger(),
|
||||
]
|
||||
self.model_config = vllm_config.model_config
|
||||
|
||||
# Tokenizer (+ ensure liveness if running in another process).
|
||||
@ -82,7 +85,6 @@ class AsyncLLM(EngineClient):
|
||||
asyncio_mode=True,
|
||||
vllm_config=vllm_config,
|
||||
executor_class=executor_class,
|
||||
log_stats=self.log_stats,
|
||||
)
|
||||
|
||||
self.output_handler: Optional[asyncio.Task] = None
|
||||
@ -94,7 +96,6 @@ class AsyncLLM(EngineClient):
|
||||
engine_config: Optional[VllmConfig] = None,
|
||||
start_engine_loop: bool = True,
|
||||
usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
|
||||
stat_loggers: Optional[Dict[str, StatLoggerBase]] = None,
|
||||
) -> "AsyncLLM":
|
||||
"""Create an AsyncLLM from the EngineArgs."""
|
||||
|
||||
@ -114,7 +115,6 @@ class AsyncLLM(EngineClient):
|
||||
log_stats=not engine_args.disable_log_stats,
|
||||
start_engine_loop=start_engine_loop,
|
||||
usage_context=usage_context,
|
||||
stat_loggers=stat_loggers,
|
||||
)
|
||||
|
||||
def shutdown(self):
|
||||
@ -254,7 +254,8 @@ class AsyncLLM(EngineClient):
|
||||
outputs = await self.engine_core.get_output_async()
|
||||
|
||||
# 2) Detokenize based on the output.
|
||||
request_outputs, reqs_to_abort = self.detokenizer.step(outputs)
|
||||
request_outputs, reqs_to_abort = self.detokenizer.step(
|
||||
outputs.outputs)
|
||||
|
||||
# 3) Put the RequestOutputs into the per-request queues.
|
||||
self._process_request_outputs(request_outputs)
|
||||
@ -262,6 +263,9 @@ class AsyncLLM(EngineClient):
|
||||
# 4) Abort any requests that finished due to stop strings.
|
||||
await self.engine_core.abort_requests_async(reqs_to_abort)
|
||||
|
||||
# 5) Log any stats.
|
||||
await self._log_stats(scheduler_stats=outputs.scheduler_stats)
|
||||
|
||||
except Exception as e:
|
||||
logger.exception("EngineCore output handler hit an error: %s", e)
|
||||
kill_process_tree(os.getpid())
|
||||
@ -278,6 +282,14 @@ class AsyncLLM(EngineClient):
|
||||
if request_id in self.rid_to_queue:
|
||||
del self.rid_to_queue[request_id]
|
||||
|
||||
async def _log_stats(self, scheduler_stats: SchedulerStats):
|
||||
"""Log stats to the stat loggers."""
|
||||
if not self.log_stats:
|
||||
return
|
||||
|
||||
for logger in self.stat_loggers:
|
||||
logger.log(scheduler_stats=scheduler_stats)
|
||||
|
||||
def encode(
|
||||
self,
|
||||
prompt: PromptType,
|
||||
|
||||
@ -17,9 +17,9 @@ from vllm.transformers_utils.config import (
|
||||
maybe_register_config_serialize_by_value)
|
||||
from vllm.utils import get_exception_traceback, zmq_socket_ctx
|
||||
from vllm.v1.core.scheduler import Scheduler
|
||||
from vllm.v1.engine import (EngineCoreOutput, EngineCoreOutputs,
|
||||
EngineCoreProfile, EngineCoreRequest,
|
||||
EngineCoreRequestType, EngineCoreRequestUnion)
|
||||
from vllm.v1.engine import (EngineCoreOutputs, EngineCoreProfile,
|
||||
EngineCoreRequest, EngineCoreRequestType,
|
||||
EngineCoreRequestUnion)
|
||||
from vllm.v1.engine.mm_input_mapper import MMInputMapperServer
|
||||
from vllm.v1.executor.abstract import Executor
|
||||
from vllm.v1.request import Request, RequestStatus
|
||||
@ -28,9 +28,7 @@ from vllm.version import __version__ as VLLM_VERSION
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
POLLING_TIMEOUT_MS = 5000
|
||||
POLLING_TIMEOUT_S = POLLING_TIMEOUT_MS // 1000
|
||||
LOGGING_TIME_S = 5
|
||||
POLLING_TIMEOUT_S = 2.5
|
||||
|
||||
|
||||
class EngineCore:
|
||||
@ -40,10 +38,8 @@ class EngineCore:
|
||||
self,
|
||||
vllm_config: VllmConfig,
|
||||
executor_class: Type[Executor],
|
||||
log_stats: bool = False,
|
||||
):
|
||||
assert vllm_config.model_config.runner_type != "pooling"
|
||||
self.log_stats = log_stats
|
||||
|
||||
logger.info("Initializing an LLM engine (v%s) with config: %s",
|
||||
VLLM_VERSION, vllm_config)
|
||||
@ -62,8 +58,6 @@ class EngineCore:
|
||||
vllm_config.cache_config,
|
||||
vllm_config.lora_config)
|
||||
|
||||
self._last_logging_time = time.time()
|
||||
|
||||
self.mm_input_mapper_server = MMInputMapperServer(
|
||||
vllm_config.model_config)
|
||||
|
||||
@ -114,11 +108,12 @@ class EngineCore:
|
||||
self.scheduler.finish_requests(request_ids,
|
||||
RequestStatus.FINISHED_ABORTED)
|
||||
|
||||
def step(self) -> List[EngineCoreOutput]:
|
||||
def step(self) -> EngineCoreOutputs:
|
||||
"""Schedule, execute, and make output."""
|
||||
|
||||
if not self.scheduler.has_unfinished_requests():
|
||||
return []
|
||||
return EngineCoreOutputs(
|
||||
outputs=[], scheduler_stats=self.scheduler.make_stats())
|
||||
|
||||
scheduler_output = self.scheduler.schedule()
|
||||
output = self.model_executor.execute_model(scheduler_output)
|
||||
@ -145,7 +140,9 @@ class EngineCoreProc(EngineCore):
|
||||
executor_class: Type[Executor],
|
||||
log_stats: bool = False,
|
||||
):
|
||||
super().__init__(vllm_config, executor_class, log_stats)
|
||||
super().__init__(vllm_config, executor_class)
|
||||
|
||||
self.log_stats = log_stats
|
||||
|
||||
# Background Threads and Queues for IO. These enable us to
|
||||
# overlap ZMQ socket IO with GPU since they release the GIL,
|
||||
@ -153,7 +150,7 @@ class EngineCoreProc(EngineCore):
|
||||
# model forward pass.
|
||||
# Threads handle Socket <-> Queues and core_busy_loop uses Queue.
|
||||
self.input_queue: queue.Queue[EngineCoreRequestUnion] = queue.Queue()
|
||||
self.output_queue: queue.Queue[List[EngineCoreOutput]] = queue.Queue()
|
||||
self.output_queue: queue.Queue[EngineCoreOutputs] = queue.Queue()
|
||||
threading.Thread(target=self.process_input_socket,
|
||||
args=(input_path, ),
|
||||
daemon=True).start()
|
||||
@ -217,8 +214,10 @@ class EngineCoreProc(EngineCore):
|
||||
self._handle_client_request(req)
|
||||
break
|
||||
except queue.Empty:
|
||||
self._log_stats()
|
||||
logger.debug("EngineCore busy loop waiting.")
|
||||
# Break out the loop so we can log_stats in step().
|
||||
if self.log_stats:
|
||||
break
|
||||
except BaseException:
|
||||
raise
|
||||
|
||||
@ -230,28 +229,9 @@ class EngineCoreProc(EngineCore):
|
||||
# 3) Step the engine core.
|
||||
outputs = self.step()
|
||||
|
||||
# 4) Put EngineCoreOutputs into the output queue.
|
||||
# 5) Put EngineCoreOutputs into the output queue.
|
||||
self.output_queue.put_nowait(outputs)
|
||||
|
||||
self._log_stats()
|
||||
|
||||
def _log_stats(self):
|
||||
"""Log basic stats every LOGGING_TIME_S"""
|
||||
|
||||
if not self.log_stats:
|
||||
return
|
||||
|
||||
now = time.time()
|
||||
|
||||
if now - self._last_logging_time > LOGGING_TIME_S:
|
||||
logger.info(
|
||||
"RUNNING: %s | WAITING: %s",
|
||||
len(self.scheduler.running),
|
||||
len(self.scheduler.waiting),
|
||||
)
|
||||
|
||||
self._last_logging_time = now
|
||||
|
||||
def _handle_client_request(self, request: EngineCoreRequestUnion) -> None:
|
||||
"""Handle EngineCoreRequest or EngineCoreABORT from Client."""
|
||||
|
||||
@ -301,7 +281,6 @@ class EngineCoreProc(EngineCore):
|
||||
|
||||
with zmq_socket_ctx(output_path, zmq.constants.PUSH) as socket:
|
||||
while True:
|
||||
engine_core_outputs = self.output_queue.get()
|
||||
outputs = EngineCoreOutputs(outputs=engine_core_outputs)
|
||||
outputs = self.output_queue.get()
|
||||
encoder.encode_into(outputs, buffer)
|
||||
socket.send_multipart((buffer, ), copy=False)
|
||||
|
||||
@ -12,9 +12,9 @@ from vllm.config import VllmConfig
|
||||
from vllm.logger import init_logger
|
||||
from vllm.utils import (get_open_zmq_ipc_path, kill_process_tree,
|
||||
make_zmq_socket)
|
||||
from vllm.v1.engine import (EngineCoreOutput, EngineCoreOutputs,
|
||||
EngineCoreProfile, EngineCoreRequest,
|
||||
EngineCoreRequestType, EngineCoreRequestUnion)
|
||||
from vllm.v1.engine import (EngineCoreOutputs, EngineCoreProfile,
|
||||
EngineCoreRequest, EngineCoreRequestType,
|
||||
EngineCoreRequestUnion)
|
||||
from vllm.v1.engine.core import EngineCore, EngineCoreProc
|
||||
from vllm.v1.executor.abstract import Executor
|
||||
from vllm.v1.serial_utils import PickleEncoder
|
||||
@ -40,7 +40,6 @@ class EngineCoreClient(ABC):
|
||||
asyncio_mode: bool,
|
||||
vllm_config: VllmConfig,
|
||||
executor_class: Type[Executor],
|
||||
log_stats: bool = False,
|
||||
) -> "EngineCoreClient":
|
||||
|
||||
# TODO: support this for debugging purposes.
|
||||
@ -50,18 +49,18 @@ class EngineCoreClient(ABC):
|
||||
"is not currently supported.")
|
||||
|
||||
if multiprocess_mode and asyncio_mode:
|
||||
return AsyncMPClient(vllm_config, executor_class, log_stats)
|
||||
return AsyncMPClient(vllm_config, executor_class)
|
||||
|
||||
if multiprocess_mode and not asyncio_mode:
|
||||
return SyncMPClient(vllm_config, executor_class, log_stats)
|
||||
return SyncMPClient(vllm_config, executor_class)
|
||||
|
||||
return InprocClient(vllm_config, executor_class, log_stats)
|
||||
return InprocClient(vllm_config, executor_class)
|
||||
|
||||
@abstractmethod
|
||||
def shutdown(self):
|
||||
...
|
||||
|
||||
def get_output(self) -> List[EngineCoreOutput]:
|
||||
def get_output(self) -> EngineCoreOutputs:
|
||||
raise NotImplementedError
|
||||
|
||||
def add_request(self, request: EngineCoreRequest) -> None:
|
||||
@ -73,7 +72,7 @@ class EngineCoreClient(ABC):
|
||||
def abort_requests(self, request_ids: List[str]) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
async def get_output_async(self) -> List[EngineCoreOutput]:
|
||||
async def get_output_async(self) -> EngineCoreOutputs:
|
||||
raise NotImplementedError
|
||||
|
||||
async def add_request_async(self, request: EngineCoreRequest) -> None:
|
||||
@ -99,7 +98,7 @@ class InprocClient(EngineCoreClient):
|
||||
def __init__(self, *args, **kwargs):
|
||||
self.engine_core = EngineCore(*args, **kwargs)
|
||||
|
||||
def get_output(self) -> List[EngineCoreOutput]:
|
||||
def get_output(self) -> EngineCoreOutputs:
|
||||
return self.engine_core.step()
|
||||
|
||||
def add_request(self, request: EngineCoreRequest) -> None:
|
||||
@ -133,7 +132,7 @@ class MPClient(EngineCoreClient):
|
||||
asyncio_mode: bool,
|
||||
vllm_config: VllmConfig,
|
||||
executor_class: Type[Executor],
|
||||
log_stats: bool = False,
|
||||
log_stats: bool,
|
||||
):
|
||||
# The child processes will send SIGUSR1 when unrecoverable
|
||||
# errors happen. We kill the process tree here so that the
|
||||
@ -194,22 +193,19 @@ class MPClient(EngineCoreClient):
|
||||
class SyncMPClient(MPClient):
|
||||
"""Synchronous client for multi-proc EngineCore."""
|
||||
|
||||
def __init__(self,
|
||||
vllm_config: VllmConfig,
|
||||
executor_class: Type[Executor],
|
||||
log_stats: bool = False):
|
||||
def __init__(self, vllm_config: VllmConfig,
|
||||
executor_class: Type[Executor]):
|
||||
super().__init__(
|
||||
asyncio_mode=False,
|
||||
vllm_config=vllm_config,
|
||||
executor_class=executor_class,
|
||||
log_stats=log_stats,
|
||||
log_stats=False,
|
||||
)
|
||||
|
||||
def get_output(self) -> List[EngineCoreOutput]:
|
||||
def get_output(self) -> EngineCoreOutputs:
|
||||
|
||||
(frame, ) = self.output_socket.recv_multipart(copy=False)
|
||||
engine_core_outputs = self.decoder.decode(frame.buffer).outputs
|
||||
return engine_core_outputs
|
||||
return self.decoder.decode(frame.buffer)
|
||||
|
||||
def _send_input(self, request_type: EngineCoreRequestType,
|
||||
request: EngineCoreRequestUnion) -> None:
|
||||
@ -235,23 +231,19 @@ class SyncMPClient(MPClient):
|
||||
class AsyncMPClient(MPClient):
|
||||
"""Asyncio-compatible client for multi-proc EngineCore."""
|
||||
|
||||
def __init__(self,
|
||||
vllm_config: VllmConfig,
|
||||
executor_class: Type[Executor],
|
||||
log_stats: bool = False):
|
||||
def __init__(self, vllm_config: VllmConfig,
|
||||
executor_class: Type[Executor]):
|
||||
super().__init__(
|
||||
asyncio_mode=True,
|
||||
vllm_config=vllm_config,
|
||||
executor_class=executor_class,
|
||||
log_stats=log_stats,
|
||||
log_stats=True,
|
||||
)
|
||||
|
||||
async def get_output_async(self) -> List[EngineCoreOutput]:
|
||||
async def get_output_async(self) -> EngineCoreOutputs:
|
||||
|
||||
frames = await self.output_socket.recv_multipart(copy=False)
|
||||
engine_core_outputs = self.decoder.decode(frames[0].buffer).outputs
|
||||
|
||||
return engine_core_outputs
|
||||
return self.decoder.decode(frames[0].buffer)
|
||||
|
||||
async def _send_input(self, request_type: EngineCoreRequestType,
|
||||
request: EngineCoreRequestUnion) -> None:
|
||||
|
||||
@ -74,7 +74,6 @@ class LLMEngine:
|
||||
asyncio_mode=False,
|
||||
vllm_config=vllm_config,
|
||||
executor_class=executor_class,
|
||||
log_stats=False,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@ -147,11 +146,11 @@ class LLMEngine:
|
||||
def step(self) -> List[RequestOutput]:
|
||||
|
||||
# 1) Get EngineCoreOutput from the EngineCore.
|
||||
engine_core_outputs = self.engine_core.get_output()
|
||||
outputs = self.engine_core.get_output()
|
||||
|
||||
# 2) Detokenizer the EngineCoreOutput.
|
||||
request_outputs, requests_to_abort = self.detokenizer.step(
|
||||
engine_core_outputs)
|
||||
outputs.outputs)
|
||||
|
||||
# 3) Abort requests that finished due to stopping criteria.
|
||||
if requests_to_abort:
|
||||
|
||||
0
vllm/v1/metrics/__init__.py
Normal file
0
vllm/v1/metrics/__init__.py
Normal file
38
vllm/v1/metrics/loggers.py
Normal file
38
vllm/v1/metrics/loggers.py
Normal file
@ -0,0 +1,38 @@
|
||||
import time
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
from vllm.logger import init_logger
|
||||
from vllm.v1.metrics.stats import SchedulerStats
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
_LOCAL_LOGGING_INTERVAL_SEC = 5.0
|
||||
|
||||
|
||||
class StatLoggerBase(ABC):
|
||||
|
||||
@abstractmethod
|
||||
def log(self, scheduler_stats: SchedulerStats):
|
||||
...
|
||||
|
||||
|
||||
class LoggingStatLogger(StatLoggerBase):
|
||||
|
||||
def __init__(self):
|
||||
self.last_log_time = time.monotonic()
|
||||
|
||||
def log(self, scheduler_stats: SchedulerStats):
|
||||
"""Log Stats to standard output."""
|
||||
|
||||
# Log every _LOCAL_LOGGING_INTERVAL_SEC.
|
||||
now = time.monotonic()
|
||||
if now - self.last_log_time < _LOCAL_LOGGING_INTERVAL_SEC:
|
||||
return
|
||||
self.last_log_time = now
|
||||
|
||||
# Format and print output.
|
||||
logger.info(
|
||||
"Running: %d reqs, Waiting: %d reqs ",
|
||||
scheduler_stats.num_running_reqs,
|
||||
scheduler_stats.num_waiting_reqs,
|
||||
)
|
||||
12
vllm/v1/metrics/stats.py
Normal file
12
vllm/v1/metrics/stats.py
Normal file
@ -0,0 +1,12 @@
|
||||
from dataclasses import dataclass
|
||||
|
||||
|
||||
@dataclass
|
||||
class SchedulerStats:
|
||||
"""Stats associated with the scheduler."""
|
||||
|
||||
num_running_reqs: int = 0
|
||||
num_waiting_reqs: int = 0
|
||||
|
||||
# gpu_cache_usage: float = 0.0
|
||||
# gpu_prefix_cache_hit_rate: float = 0.0
|
||||
Loading…
x
Reference in New Issue
Block a user