[V1][Core][1/n] Logging and Metrics (#11962)

Signed-off-by: rshaw@neuralmagic.com <rshaw@neuralmagic.com>
This commit is contained in:
Robert Shaw 2025-01-12 16:02:02 -05:00 committed by GitHub
parent 263a870ee1
commit 9597a095f2
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
11 changed files with 129 additions and 84 deletions

View File

@ -80,7 +80,7 @@ def test_engine_core(monkeypatch):
assert len(engine_core.scheduler.running) == 4
# Loop through until they are all done.
while len(engine_core.step()) > 0:
while len(engine_core.step().outputs) > 0:
pass
assert len(engine_core.scheduler.waiting) == 0
@ -170,7 +170,7 @@ def test_engine_core_advanced_sampling(monkeypatch):
assert len(engine_core.scheduler.waiting) == 1
assert len(engine_core.scheduler.running) == 0
# Loop through until they are all done.
while len(engine_core.step()) > 0:
while len(engine_core.step().outputs) > 0:
pass
assert len(engine_core.scheduler.waiting) == 0

View File

@ -43,7 +43,7 @@ def make_request(params: SamplingParams) -> EngineCoreRequest:
def loop_until_done(client: EngineCoreClient, outputs: Dict):
while True:
engine_core_outputs = client.get_output()
engine_core_outputs = client.get_output().outputs
if len(engine_core_outputs) == 0:
break
@ -61,7 +61,7 @@ def loop_until_done(client: EngineCoreClient, outputs: Dict):
async def loop_until_done_async(client: EngineCoreClient, outputs: Dict):
while True:
engine_core_outputs = await client.get_output_async()
engine_core_outputs = await client.get_output_async().outputs
if len(engine_core_outputs) == 0:
break

View File

@ -8,7 +8,8 @@ from vllm.logger import init_logger
from vllm.sampling_params import SamplingParams
from vllm.v1.core.encoder_cache_manager import EncoderCacheManager
from vllm.v1.core.kv_cache_manager import KVCacheManager
from vllm.v1.engine import EngineCoreOutput
from vllm.v1.engine import EngineCoreOutput, EngineCoreOutputs
from vllm.v1.metrics.stats import SchedulerStats
from vllm.v1.outputs import ModelRunnerOutput
from vllm.v1.request import Request, RequestStatus
@ -394,12 +395,12 @@ class Scheduler:
self,
scheduler_output: "SchedulerOutput",
model_runner_output: "ModelRunnerOutput",
) -> List[EngineCoreOutput]:
) -> EngineCoreOutputs:
# NOTE(woosuk): This method doesn't consider speculative decoding.
sampled_token_ids = model_runner_output.sampled_token_ids
num_scheduled_tokens = scheduler_output.num_scheduled_tokens
new_running: List[Request] = []
engine_core_outputs: List[EngineCoreOutput] = []
outputs: List[EngineCoreOutput] = []
for request in self.running:
req_id = request.request_id
request.num_computed_tokens += num_scheduled_tokens[req_id]
@ -438,7 +439,7 @@ class Scheduler:
finished=request.is_finished(),
finish_reason=request.get_finished_reason(),
stop_reason=request.stop_reason)
engine_core_outputs.append(output)
outputs.append(output)
# Breakout of the loop.
if stopped:
@ -446,7 +447,10 @@ class Scheduler:
new_running.append(request)
self.running = new_running
return engine_core_outputs
return EngineCoreOutputs(
outputs=outputs,
scheduler_stats=self.make_stats(),
)
def _check_stop(self, request: Request) -> bool:
if (request.num_tokens >= self.max_model_len
@ -515,6 +519,12 @@ class Scheduler:
def has_unfinished_requests(self) -> bool:
return self.get_num_unfinished_requests() > 0
def make_stats(self) -> SchedulerStats:
return SchedulerStats(
num_running_reqs=len(self.running),
num_waiting_reqs=len(self.waiting),
)
@dataclass
class NewRequestData:

View File

@ -4,6 +4,8 @@ from typing import TYPE_CHECKING, List, Optional, Union
import msgspec
from vllm.v1.metrics.stats import SchedulerStats
if TYPE_CHECKING:
from vllm.lora.request import LoRARequest
from vllm.multimodal import MultiModalKwargs
@ -56,6 +58,7 @@ class EngineCoreOutputs(
# [num_reqs]
outputs: List[EngineCoreOutput]
scheduler_stats: SchedulerStats
@dataclass

View File

@ -4,7 +4,6 @@ from typing import AsyncGenerator, Dict, List, Mapping, Optional, Type, Union
from vllm.config import ModelConfig, VllmConfig
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.metrics_types import StatLoggerBase
from vllm.engine.protocol import EngineClient
from vllm.inputs import INPUT_REGISTRY, InputRegistry, PromptType
from vllm.inputs.preprocess import InputPreprocessor
@ -22,6 +21,8 @@ from vllm.v1.engine.core_client import EngineCoreClient
from vllm.v1.engine.detokenizer import Detokenizer
from vllm.v1.engine.processor import Processor
from vllm.v1.executor.abstract import Executor
from vllm.v1.metrics.loggers import LoggingStatLogger, StatLoggerBase
from vllm.v1.metrics.stats import SchedulerStats
logger = init_logger(__name__)
@ -34,7 +35,6 @@ class AsyncLLM(EngineClient):
executor_class: Type[Executor],
log_stats: bool,
usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
stat_loggers: Optional[Dict[str, StatLoggerBase]] = None,
input_registry: InputRegistry = INPUT_REGISTRY,
use_cached_outputs: bool = False,
log_requests: bool = True,
@ -45,7 +45,10 @@ class AsyncLLM(EngineClient):
self.log_requests = log_requests
self.log_stats = log_stats
self.stat_loggers = stat_loggers
self.stat_loggers: List[StatLoggerBase] = [
LoggingStatLogger(),
# TODO(rob): PrometheusStatLogger(),
]
self.model_config = vllm_config.model_config
# Tokenizer (+ ensure liveness if running in another process).
@ -82,7 +85,6 @@ class AsyncLLM(EngineClient):
asyncio_mode=True,
vllm_config=vllm_config,
executor_class=executor_class,
log_stats=self.log_stats,
)
self.output_handler: Optional[asyncio.Task] = None
@ -94,7 +96,6 @@ class AsyncLLM(EngineClient):
engine_config: Optional[VllmConfig] = None,
start_engine_loop: bool = True,
usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
stat_loggers: Optional[Dict[str, StatLoggerBase]] = None,
) -> "AsyncLLM":
"""Create an AsyncLLM from the EngineArgs."""
@ -114,7 +115,6 @@ class AsyncLLM(EngineClient):
log_stats=not engine_args.disable_log_stats,
start_engine_loop=start_engine_loop,
usage_context=usage_context,
stat_loggers=stat_loggers,
)
def shutdown(self):
@ -254,7 +254,8 @@ class AsyncLLM(EngineClient):
outputs = await self.engine_core.get_output_async()
# 2) Detokenize based on the output.
request_outputs, reqs_to_abort = self.detokenizer.step(outputs)
request_outputs, reqs_to_abort = self.detokenizer.step(
outputs.outputs)
# 3) Put the RequestOutputs into the per-request queues.
self._process_request_outputs(request_outputs)
@ -262,6 +263,9 @@ class AsyncLLM(EngineClient):
# 4) Abort any requests that finished due to stop strings.
await self.engine_core.abort_requests_async(reqs_to_abort)
# 5) Log any stats.
await self._log_stats(scheduler_stats=outputs.scheduler_stats)
except Exception as e:
logger.exception("EngineCore output handler hit an error: %s", e)
kill_process_tree(os.getpid())
@ -278,6 +282,14 @@ class AsyncLLM(EngineClient):
if request_id in self.rid_to_queue:
del self.rid_to_queue[request_id]
async def _log_stats(self, scheduler_stats: SchedulerStats):
"""Log stats to the stat loggers."""
if not self.log_stats:
return
for logger in self.stat_loggers:
logger.log(scheduler_stats=scheduler_stats)
def encode(
self,
prompt: PromptType,

View File

@ -17,9 +17,9 @@ from vllm.transformers_utils.config import (
maybe_register_config_serialize_by_value)
from vllm.utils import get_exception_traceback, zmq_socket_ctx
from vllm.v1.core.scheduler import Scheduler
from vllm.v1.engine import (EngineCoreOutput, EngineCoreOutputs,
EngineCoreProfile, EngineCoreRequest,
EngineCoreRequestType, EngineCoreRequestUnion)
from vllm.v1.engine import (EngineCoreOutputs, EngineCoreProfile,
EngineCoreRequest, EngineCoreRequestType,
EngineCoreRequestUnion)
from vllm.v1.engine.mm_input_mapper import MMInputMapperServer
from vllm.v1.executor.abstract import Executor
from vllm.v1.request import Request, RequestStatus
@ -28,9 +28,7 @@ from vllm.version import __version__ as VLLM_VERSION
logger = init_logger(__name__)
POLLING_TIMEOUT_MS = 5000
POLLING_TIMEOUT_S = POLLING_TIMEOUT_MS // 1000
LOGGING_TIME_S = 5
POLLING_TIMEOUT_S = 2.5
class EngineCore:
@ -40,10 +38,8 @@ class EngineCore:
self,
vllm_config: VllmConfig,
executor_class: Type[Executor],
log_stats: bool = False,
):
assert vllm_config.model_config.runner_type != "pooling"
self.log_stats = log_stats
logger.info("Initializing an LLM engine (v%s) with config: %s",
VLLM_VERSION, vllm_config)
@ -62,8 +58,6 @@ class EngineCore:
vllm_config.cache_config,
vllm_config.lora_config)
self._last_logging_time = time.time()
self.mm_input_mapper_server = MMInputMapperServer(
vllm_config.model_config)
@ -114,11 +108,12 @@ class EngineCore:
self.scheduler.finish_requests(request_ids,
RequestStatus.FINISHED_ABORTED)
def step(self) -> List[EngineCoreOutput]:
def step(self) -> EngineCoreOutputs:
"""Schedule, execute, and make output."""
if not self.scheduler.has_unfinished_requests():
return []
return EngineCoreOutputs(
outputs=[], scheduler_stats=self.scheduler.make_stats())
scheduler_output = self.scheduler.schedule()
output = self.model_executor.execute_model(scheduler_output)
@ -145,7 +140,9 @@ class EngineCoreProc(EngineCore):
executor_class: Type[Executor],
log_stats: bool = False,
):
super().__init__(vllm_config, executor_class, log_stats)
super().__init__(vllm_config, executor_class)
self.log_stats = log_stats
# Background Threads and Queues for IO. These enable us to
# overlap ZMQ socket IO with GPU since they release the GIL,
@ -153,7 +150,7 @@ class EngineCoreProc(EngineCore):
# model forward pass.
# Threads handle Socket <-> Queues and core_busy_loop uses Queue.
self.input_queue: queue.Queue[EngineCoreRequestUnion] = queue.Queue()
self.output_queue: queue.Queue[List[EngineCoreOutput]] = queue.Queue()
self.output_queue: queue.Queue[EngineCoreOutputs] = queue.Queue()
threading.Thread(target=self.process_input_socket,
args=(input_path, ),
daemon=True).start()
@ -217,8 +214,10 @@ class EngineCoreProc(EngineCore):
self._handle_client_request(req)
break
except queue.Empty:
self._log_stats()
logger.debug("EngineCore busy loop waiting.")
# Break out the loop so we can log_stats in step().
if self.log_stats:
break
except BaseException:
raise
@ -230,28 +229,9 @@ class EngineCoreProc(EngineCore):
# 3) Step the engine core.
outputs = self.step()
# 4) Put EngineCoreOutputs into the output queue.
# 5) Put EngineCoreOutputs into the output queue.
self.output_queue.put_nowait(outputs)
self._log_stats()
def _log_stats(self):
"""Log basic stats every LOGGING_TIME_S"""
if not self.log_stats:
return
now = time.time()
if now - self._last_logging_time > LOGGING_TIME_S:
logger.info(
"RUNNING: %s | WAITING: %s",
len(self.scheduler.running),
len(self.scheduler.waiting),
)
self._last_logging_time = now
def _handle_client_request(self, request: EngineCoreRequestUnion) -> None:
"""Handle EngineCoreRequest or EngineCoreABORT from Client."""
@ -301,7 +281,6 @@ class EngineCoreProc(EngineCore):
with zmq_socket_ctx(output_path, zmq.constants.PUSH) as socket:
while True:
engine_core_outputs = self.output_queue.get()
outputs = EngineCoreOutputs(outputs=engine_core_outputs)
outputs = self.output_queue.get()
encoder.encode_into(outputs, buffer)
socket.send_multipart((buffer, ), copy=False)

View File

@ -12,9 +12,9 @@ from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.utils import (get_open_zmq_ipc_path, kill_process_tree,
make_zmq_socket)
from vllm.v1.engine import (EngineCoreOutput, EngineCoreOutputs,
EngineCoreProfile, EngineCoreRequest,
EngineCoreRequestType, EngineCoreRequestUnion)
from vllm.v1.engine import (EngineCoreOutputs, EngineCoreProfile,
EngineCoreRequest, EngineCoreRequestType,
EngineCoreRequestUnion)
from vllm.v1.engine.core import EngineCore, EngineCoreProc
from vllm.v1.executor.abstract import Executor
from vllm.v1.serial_utils import PickleEncoder
@ -40,7 +40,6 @@ class EngineCoreClient(ABC):
asyncio_mode: bool,
vllm_config: VllmConfig,
executor_class: Type[Executor],
log_stats: bool = False,
) -> "EngineCoreClient":
# TODO: support this for debugging purposes.
@ -50,18 +49,18 @@ class EngineCoreClient(ABC):
"is not currently supported.")
if multiprocess_mode and asyncio_mode:
return AsyncMPClient(vllm_config, executor_class, log_stats)
return AsyncMPClient(vllm_config, executor_class)
if multiprocess_mode and not asyncio_mode:
return SyncMPClient(vllm_config, executor_class, log_stats)
return SyncMPClient(vllm_config, executor_class)
return InprocClient(vllm_config, executor_class, log_stats)
return InprocClient(vllm_config, executor_class)
@abstractmethod
def shutdown(self):
...
def get_output(self) -> List[EngineCoreOutput]:
def get_output(self) -> EngineCoreOutputs:
raise NotImplementedError
def add_request(self, request: EngineCoreRequest) -> None:
@ -73,7 +72,7 @@ class EngineCoreClient(ABC):
def abort_requests(self, request_ids: List[str]) -> None:
raise NotImplementedError
async def get_output_async(self) -> List[EngineCoreOutput]:
async def get_output_async(self) -> EngineCoreOutputs:
raise NotImplementedError
async def add_request_async(self, request: EngineCoreRequest) -> None:
@ -99,7 +98,7 @@ class InprocClient(EngineCoreClient):
def __init__(self, *args, **kwargs):
self.engine_core = EngineCore(*args, **kwargs)
def get_output(self) -> List[EngineCoreOutput]:
def get_output(self) -> EngineCoreOutputs:
return self.engine_core.step()
def add_request(self, request: EngineCoreRequest) -> None:
@ -133,7 +132,7 @@ class MPClient(EngineCoreClient):
asyncio_mode: bool,
vllm_config: VllmConfig,
executor_class: Type[Executor],
log_stats: bool = False,
log_stats: bool,
):
# The child processes will send SIGUSR1 when unrecoverable
# errors happen. We kill the process tree here so that the
@ -194,22 +193,19 @@ class MPClient(EngineCoreClient):
class SyncMPClient(MPClient):
"""Synchronous client for multi-proc EngineCore."""
def __init__(self,
vllm_config: VllmConfig,
executor_class: Type[Executor],
log_stats: bool = False):
def __init__(self, vllm_config: VllmConfig,
executor_class: Type[Executor]):
super().__init__(
asyncio_mode=False,
vllm_config=vllm_config,
executor_class=executor_class,
log_stats=log_stats,
log_stats=False,
)
def get_output(self) -> List[EngineCoreOutput]:
def get_output(self) -> EngineCoreOutputs:
(frame, ) = self.output_socket.recv_multipart(copy=False)
engine_core_outputs = self.decoder.decode(frame.buffer).outputs
return engine_core_outputs
return self.decoder.decode(frame.buffer)
def _send_input(self, request_type: EngineCoreRequestType,
request: EngineCoreRequestUnion) -> None:
@ -235,23 +231,19 @@ class SyncMPClient(MPClient):
class AsyncMPClient(MPClient):
"""Asyncio-compatible client for multi-proc EngineCore."""
def __init__(self,
vllm_config: VllmConfig,
executor_class: Type[Executor],
log_stats: bool = False):
def __init__(self, vllm_config: VllmConfig,
executor_class: Type[Executor]):
super().__init__(
asyncio_mode=True,
vllm_config=vllm_config,
executor_class=executor_class,
log_stats=log_stats,
log_stats=True,
)
async def get_output_async(self) -> List[EngineCoreOutput]:
async def get_output_async(self) -> EngineCoreOutputs:
frames = await self.output_socket.recv_multipart(copy=False)
engine_core_outputs = self.decoder.decode(frames[0].buffer).outputs
return engine_core_outputs
return self.decoder.decode(frames[0].buffer)
async def _send_input(self, request_type: EngineCoreRequestType,
request: EngineCoreRequestUnion) -> None:

View File

@ -74,7 +74,6 @@ class LLMEngine:
asyncio_mode=False,
vllm_config=vllm_config,
executor_class=executor_class,
log_stats=False,
)
@classmethod
@ -147,11 +146,11 @@ class LLMEngine:
def step(self) -> List[RequestOutput]:
# 1) Get EngineCoreOutput from the EngineCore.
engine_core_outputs = self.engine_core.get_output()
outputs = self.engine_core.get_output()
# 2) Detokenizer the EngineCoreOutput.
request_outputs, requests_to_abort = self.detokenizer.step(
engine_core_outputs)
outputs.outputs)
# 3) Abort requests that finished due to stopping criteria.
if requests_to_abort:

View File

View File

@ -0,0 +1,38 @@
import time
from abc import ABC, abstractmethod
from vllm.logger import init_logger
from vllm.v1.metrics.stats import SchedulerStats
logger = init_logger(__name__)
_LOCAL_LOGGING_INTERVAL_SEC = 5.0
class StatLoggerBase(ABC):
@abstractmethod
def log(self, scheduler_stats: SchedulerStats):
...
class LoggingStatLogger(StatLoggerBase):
def __init__(self):
self.last_log_time = time.monotonic()
def log(self, scheduler_stats: SchedulerStats):
"""Log Stats to standard output."""
# Log every _LOCAL_LOGGING_INTERVAL_SEC.
now = time.monotonic()
if now - self.last_log_time < _LOCAL_LOGGING_INTERVAL_SEC:
return
self.last_log_time = now
# Format and print output.
logger.info(
"Running: %d reqs, Waiting: %d reqs ",
scheduler_stats.num_running_reqs,
scheduler_stats.num_waiting_reqs,
)

12
vllm/v1/metrics/stats.py Normal file
View File

@ -0,0 +1,12 @@
from dataclasses import dataclass
@dataclass
class SchedulerStats:
"""Stats associated with the scheduler."""
num_running_reqs: int = 0
num_waiting_reqs: int = 0
# gpu_cache_usage: float = 0.0
# gpu_prefix_cache_hit_rate: float = 0.0