[V1][Metrics] Allow V1 AsyncLLM to use custom logger (#14661)

Signed-off-by: Zijing Liu <liuzijing2014@gmail.com>
Signed-off-by: Mark McLoughlin <markmc@redhat.com>
Signed-off-by: Nick Hill <nhill@redhat.com>
Co-authored-by: Mark McLoughlin <markmc@redhat.com>
Co-authored-by: Nick Hill <nhill@redhat.com>

commit 53e8cf53a4 (parent 54271bb766)
tests/v1/engine/test_async_llm.py

@@ -3,16 +3,19 @@
 import asyncio
 from contextlib import ExitStack
 from typing import Optional
+from unittest.mock import MagicMock

 import pytest

 from vllm import SamplingParams
 from vllm.assets.image import ImageAsset
+from vllm.config import VllmConfig
 from vllm.engine.arg_utils import AsyncEngineArgs
 from vllm.inputs import PromptType
 from vllm.platforms import current_platform
 from vllm.sampling_params import RequestOutputKind
 from vllm.v1.engine.async_llm import AsyncLLM
+from vllm.v1.metrics.loggers import LoggingStatLogger

 if not current_platform.is_cuda():
     pytest.skip(reason="V1 currently only supported on CUDA.",
@@ -216,3 +219,33 @@ async def test_finished_flag(monkeypatch: pytest.MonkeyPatch, n: int,
     # Assert only the last output has the finished flag set
     assert all(not out.finished for out in outputs[:-1])
     assert outputs[-1].finished
+
+
+class MockLoggingStatLogger(LoggingStatLogger):
+
+    def __init__(self, vllm_config: VllmConfig, engine_index: int = 0):
+        super().__init__(vllm_config, engine_index)
+        self.log = MagicMock()
+
+
+@pytest.mark.asyncio
+async def test_customize_loggers(monkeypatch):
+    """Test that we can customize the loggers.
+    If a customized logger is provided at the init, it should
+    be used directly.
+    """
+
+    with monkeypatch.context() as m, ExitStack() as after:
+        m.setenv("VLLM_USE_V1", "1")
+
+        engine = AsyncLLM.from_engine_args(
+            TEXT_ENGINE_ARGS,
+            stat_loggers=[MockLoggingStatLogger],
+        )
+        after.callback(engine.shutdown)
+
+        await engine.do_log_stats()
+
+        assert len(engine.stat_loggers) == 1
+        assert len(engine.stat_loggers[0]) == 1
+        engine.stat_loggers[0][0].log.assert_called_once()
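The test spies on log() with a MagicMock; outside of tests, the same hooks support a real logger. A minimal sketch, assuming the StatLoggerBase record()/log() hooks shown in this diff and that IterationStats exposes a num_generation_tokens field; ThroughputStatLogger and engine_args are illustrative names, not part of the commit:

from typing import Optional

from vllm.config import VllmConfig
from vllm.v1.engine.async_llm import AsyncLLM
from vllm.v1.metrics.loggers import StatLoggerBase
from vllm.v1.metrics.stats import IterationStats, SchedulerStats


class ThroughputStatLogger(StatLoggerBase):
    """Hypothetical user-defined logger that counts generated tokens."""

    def __init__(self, vllm_config: VllmConfig, engine_index: int = 0):
        self.engine_index = engine_index
        self.num_generation_tokens = 0

    def record(self, scheduler_stats: SchedulerStats,
               iteration_stats: Optional[IterationStats]):
        # Called by the engine each iteration; iteration_stats may be None.
        if iteration_stats is not None:
            self.num_generation_tokens += iteration_stats.num_generation_tokens

    def log(self):
        # Called on engine.do_log_stats(), mirroring the mocked call above.
        print(f"[engine {self.engine_index}] "
              f"{self.num_generation_tokens} tokens generated")


# The class itself satisfies the factory contract: it is callable with
# (vllm_config, engine_index) and returns a StatLoggerBase instance.
engine = AsyncLLM.from_engine_args(engine_args,  # assumed AsyncEngineArgs
                                   stat_loggers=[ThroughputStatLogger])

Passing the class in a list mirrors the test above; the engine instantiates one copy per data-parallel rank.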
vllm/v1/engine/async_llm.py

@@ -1,6 +1,5 @@
 # SPDX-License-Identifier: Apache-2.0
 import asyncio
-import logging
 from collections.abc import AsyncGenerator, Mapping
 from copy import copy
 from typing import Optional, Union
@@ -33,8 +32,8 @@ from vllm.v1.engine.output_processor import (OutputProcessor,
 from vllm.v1.engine.parallel_sampling import ParentRequest
 from vllm.v1.engine.processor import Processor
 from vllm.v1.executor.abstract import Executor
-from vllm.v1.metrics.loggers import (LoggingStatLogger, PrometheusStatLogger,
-                                     StatLoggerBase)
+from vllm.v1.metrics.loggers import (StatLoggerBase, StatLoggerFactory,
+                                     setup_default_loggers)
 from vllm.v1.metrics.stats import IterationStats, SchedulerStats

 logger = init_logger(__name__)
@@ -52,7 +51,28 @@ class AsyncLLM(EngineClient):
         use_cached_outputs: bool = False,
         log_requests: bool = True,
         start_engine_loop: bool = True,
+        stat_loggers: Optional[list[StatLoggerFactory]] = None,
     ) -> None:
+        """
+        Create an AsyncLLM.
+
+        Args:
+            vllm_config: global configuration.
+            executor_class: an Executor impl, e.g. MultiprocExecutor.
+            log_stats: Whether to log stats.
+            usage_context: Usage context of the LLM.
+            mm_registry: Multi-modal registry.
+            use_cached_outputs: Whether to use cached outputs.
+            log_requests: Whether to log requests.
+            start_engine_loop: Whether to start the engine loop.
+            stat_loggers: customized stat loggers for the engine.
+                If not provided, default stat loggers will be used.
+                PLEASE BE AWARE THAT STAT LOGGER IS NOT STABLE
+                IN V1, AND ITS BASE CLASS INTERFACE MIGHT CHANGE.
+
+        Returns:
+            None
+        """
         if not envs.VLLM_USE_V1:
             raise ValueError(
                 "Using V1 AsyncLLMEngine, but envs.VLLM_USE_V1=False. "
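Because stat_loggers now takes factories rather than instances, per-rank construction stays inside the engine, and anything extra a logger needs can be bound ahead of time. A sketch assuming only the StatLoggerFactory contract (VllmConfig, engine_index) -> StatLoggerBase from this diff; PrefixedLogger and its prefix argument are made up:

from functools import partial

from vllm.config import VllmConfig
from vllm.v1.metrics.loggers import StatLoggerBase


class PrefixedLogger(StatLoggerBase):
    """Hypothetical logger with extra constructor configuration."""

    def __init__(self, vllm_config: VllmConfig, engine_index: int = 0,
                 prefix: str = ""):
        self.prefix = prefix
        self.engine_index = engine_index

    def record(self, scheduler_stats, iteration_stats):
        pass  # omitted for brevity

    def log(self):
        print(f"{self.prefix} engine={self.engine_index}")


# partial(...) still matches Callable[[VllmConfig, int], StatLoggerBase],
# since the engine calls factory(vllm_config, engine_index) positionally.
factory = partial(PrefixedLogger, prefix="[dp]")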
@@ -66,15 +86,12 @@ class AsyncLLM(EngineClient):
         self.log_stats = log_stats

         # Set up stat loggers; independent set for each DP rank.
-        self.stat_loggers: list[list[StatLoggerBase]] = []
-        if self.log_stats:
-            for i in range(vllm_config.parallel_config.data_parallel_size):
-                loggers: list[StatLoggerBase] = []
-                if logger.isEnabledFor(logging.INFO):
-                    loggers.append(LoggingStatLogger(engine_index=i))
-                loggers.append(
-                    PrometheusStatLogger(vllm_config, engine_index=i))
-                self.stat_loggers.append(loggers)
+        self.stat_loggers: list[list[StatLoggerBase]] = setup_default_loggers(
+            vllm_config=vllm_config,
+            log_stats=self.log_stats,
+            engine_num=vllm_config.parallel_config.data_parallel_size,
+            custom_stat_loggers=stat_loggers,
+        )

         # Tokenizer (+ ensure liveness if running in another process).
         self.tokenizer = init_tokenizer_from_configs(
@@ -118,7 +135,7 @@ class AsyncLLM(EngineClient):
         vllm_config: VllmConfig,
         start_engine_loop: bool = True,
         usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
-        stat_loggers: Optional[dict[str, StatLoggerBase]] = None,
+        stat_loggers: Optional[list[StatLoggerFactory]] = None,
         disable_log_requests: bool = False,
         disable_log_stats: bool = False,
     ) -> "AsyncLLM":
@@ -129,17 +146,12 @@ class AsyncLLM(EngineClient):
                 "AsyncLLMEngine.from_vllm_config(...) or explicitly set "
                 "VLLM_USE_V1=0 or 1 and report this issue on Github.")

-        # FIXME(rob): refactor VllmConfig to include the StatLoggers
-        # include StatLogger in the Oracle decision.
-        if stat_loggers is not None:
-            raise ValueError("Custom StatLoggers are not yet supported on V1. "
-                             "Explicitly set VLLM_USE_V1=0 to disable V1.")
-
         # Create the LLMEngine.
         return cls(
             vllm_config=vllm_config,
             executor_class=Executor.get_class(vllm_config),
             start_engine_loop=start_engine_loop,
+            stat_loggers=stat_loggers,
             log_requests=not disable_log_requests,
             log_stats=not disable_log_stats,
             usage_context=usage_context,
@@ -151,6 +163,7 @@ class AsyncLLM(EngineClient):
         engine_args: AsyncEngineArgs,
         start_engine_loop: bool = True,
         usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
+        stat_loggers: Optional[list[StatLoggerFactory]] = None,
     ) -> "AsyncLLM":
         """Create an AsyncLLM from the EngineArgs."""

@@ -166,6 +179,7 @@ class AsyncLLM(EngineClient):
             log_stats=not engine_args.disable_log_stats,
             start_engine_loop=start_engine_loop,
             usage_context=usage_context,
+            stat_loggers=stat_loggers,
         )

     def __del__(self):
vllm/v1/engine/llm_engine.py

@@ -10,7 +10,6 @@ import vllm.envs as envs
 from vllm.config import ParallelConfig, VllmConfig
 from vllm.distributed import stateless_destroy_torch_distributed_process_group
 from vllm.engine.arg_utils import EngineArgs
-from vllm.engine.metrics_types import StatLoggerBase
 from vllm.inputs import PromptType
 from vllm.logger import init_logger
 from vllm.lora.request import LoRARequest
@@ -28,6 +27,7 @@ from vllm.v1.engine.output_processor import OutputProcessor
 from vllm.v1.engine.parallel_sampling import ParentRequest
 from vllm.v1.engine.processor import Processor
 from vllm.v1.executor.abstract import Executor
+from vllm.v1.metrics.loggers import StatLoggerFactory

 logger = init_logger(__name__)

@@ -43,7 +43,7 @@ class LLMEngine:
         executor_class: type[Executor],
         log_stats: bool,
         usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
-        stat_loggers: Optional[dict[str, StatLoggerBase]] = None,
+        stat_loggers: Optional[list[StatLoggerFactory]] = None,
         mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY,
         use_cached_outputs: bool = False,
         multiprocess_mode: bool = False,
@@ -55,6 +55,11 @@ class LLMEngine:
                 "LLMEngine.from_vllm_config(...) or explicitly set "
                 "VLLM_USE_V1=0 or 1 and report this issue on Github.")

+        if stat_loggers is not None:
+            raise NotImplementedError(
+                "Passing StatLoggers to LLMEngine in V1 is not yet supported. "
+                "Set VLLM_USE_V1=0 and file and issue on Github.")
+
         self.vllm_config = vllm_config
         self.model_config = vllm_config.model_config
         self.cache_config = vllm_config.cache_config
@@ -101,14 +106,9 @@ class LLMEngine:
         cls,
         vllm_config: VllmConfig,
         usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
-        stat_loggers: Optional[dict[str, StatLoggerBase]] = None,
+        stat_loggers: Optional[list[StatLoggerFactory]] = None,
         disable_log_stats: bool = False,
     ) -> "LLMEngine":
-        if stat_loggers is not None:
-            raise NotImplementedError(
-                "Passing StatLoggers to V1 is not yet supported. "
-                "Set VLLM_USE_V1=0 and file and issue on Github.")
-
         return cls(vllm_config=vllm_config,
                    executor_class=Executor.get_class(vllm_config),
                    log_stats=(not disable_log_stats),
@@ -121,7 +121,7 @@ class LLMEngine:
         cls,
         engine_args: EngineArgs,
         usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
-        stat_loggers: Optional[dict[str, StatLoggerBase]] = None,
+        stat_loggers: Optional[list[StatLoggerFactory]] = None,
         enable_multiprocessing: bool = False,
     ) -> "LLMEngine":
         """Creates an LLM engine from the engine arguments."""
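For the synchronous V1 LLMEngine the parameter is accepted in the signature but not yet wired up, so the guard moved into __init__ raises immediately. A sketch, reusing the hypothetical logger from earlier; vllm_config and executor_class are assumed to be built elsewhere:

from vllm.v1.engine.llm_engine import LLMEngine

try:
    LLMEngine(vllm_config=vllm_config,        # assumed VllmConfig
              executor_class=executor_class,  # assumed Executor subclass
              log_stats=True,
              stat_loggers=[ThroughputStatLogger])
except NotImplementedError:
    pass  # expected on V1 until LLMEngine gains factory support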
vllm/v1/metrics/loggers.py

@@ -1,8 +1,9 @@
 # SPDX-License-Identifier: Apache-2.0

+import logging
 import time
 from abc import ABC, abstractmethod
-from typing import Optional
+from typing import Callable, Optional

 import numpy as np
 import prometheus_client
@@ -18,8 +19,20 @@ logger = init_logger(__name__)

 _LOCAL_LOGGING_INTERVAL_SEC = 5.0

+StatLoggerFactory = Callable[[VllmConfig, int], "StatLoggerBase"]
+

 class StatLoggerBase(ABC):
+    """Interface for logging metrics.
+
+    API users may define custom loggers that implement this interface.
+    However, note that the `SchedulerStats` and `IterationStats` classes
+    are not considered stable interfaces and may change in future versions.
+    """

     @abstractmethod
+    def __init__(self, vllm_config: VllmConfig, engine_index: int = 0):
+        ...
+
+    @abstractmethod
     def record(self, scheduler_stats: SchedulerStats,
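Since StatLoggerFactory is just Callable[[VllmConfig, int], "StatLoggerBase"], a plain function works as well as a class or a functools.partial. A sketch (make_logger is made up; LoggingStatLogger's updated signature is taken from the next hunk):

from vllm.config import VllmConfig
from vllm.v1.metrics.loggers import LoggingStatLogger, StatLoggerBase


def make_logger(vllm_config: VllmConfig, engine_index: int) -> StatLoggerBase:
    # Free to inspect vllm_config before deciding which logger to build.
    return LoggingStatLogger(vllm_config, engine_index)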
@@ -32,7 +45,7 @@ class StatLoggerBase(ABC):

 class LoggingStatLogger(StatLoggerBase):

-    def __init__(self, engine_index: int = 0):
+    def __init__(self, vllm_config: VllmConfig, engine_index: int = 0):
         self.engine_index = engine_index
         self._reset(time.monotonic())
         self.last_scheduler_stats = SchedulerStats()
@@ -462,3 +475,31 @@ def build_cudagraph_buckets(vllm_config: VllmConfig) -> list[int]:
         return buckets
     else:
         return [1, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8096]
+
+
+def setup_default_loggers(
+    vllm_config: VllmConfig,
+    log_stats: bool,
+    engine_num: int,
+    custom_stat_loggers: Optional[list[StatLoggerFactory]] = None,
+) -> list[list[StatLoggerBase]]:
+    """Setup logging and prometheus metrics."""
+    if not log_stats:
+        return []
+
+    factories: list[StatLoggerFactory]
+    if custom_stat_loggers is not None:
+        factories = custom_stat_loggers
+    else:
+        factories = [PrometheusStatLogger]
+        if logger.isEnabledFor(logging.INFO):
+            factories.append(LoggingStatLogger)
+
+    stat_loggers: list[list[StatLoggerBase]] = []
+    for i in range(engine_num):
+        per_engine_stat_loggers: list[StatLoggerBase] = []
+        for logger_factory in factories:
+            per_engine_stat_loggers.append(logger_factory(vllm_config, i))
+        stat_loggers.append(per_engine_stat_loggers)
+
+    return stat_loggers
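Two behaviors of setup_default_loggers are worth noting: custom factories replace the defaults (no Prometheus or logging logger is created beside them), and the result is indexed stat_loggers[dp_rank][logger_idx], which is exactly what the new test asserts via engine.stat_loggers[0][0]. A rough sketch of the default shape, assuming a vllm_config built elsewhere and INFO-level logging enabled:

from vllm.v1.metrics.loggers import setup_default_loggers

loggers = setup_default_loggers(vllm_config,  # assumed VllmConfig
                                log_stats=True,
                                engine_num=2)
assert len(loggers) == 2  # one inner list per DP rank
# Each inner list here holds [PrometheusStatLogger, LoggingStatLogger].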