mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-05-23 06:17:52 +08:00
add aggregator interface and abstract common logic
Signed-off-by: Lu Fang <fanglu@fb.com>
This commit is contained in:
parent
a46e279909
commit
98d535eb4f
@ -9,7 +9,7 @@ from vllm.engine.arg_utils import AsyncEngineArgs
|
|||||||
from vllm.engine.async_llm_engine import AsyncLLMEngine
|
from vllm.engine.async_llm_engine import AsyncLLMEngine
|
||||||
from vllm.outputs import RequestOutput
|
from vllm.outputs import RequestOutput
|
||||||
from vllm.sampling_params import SamplingParams
|
from vllm.sampling_params import SamplingParams
|
||||||
from vllm.v1.metrics.loggers import GlobalStatLogger, LoggingStatLogger
|
from vllm.v1.metrics.loggers import AggregatedStatLogger, LoggingStatLogger
|
||||||
|
|
||||||
"""
|
"""
|
||||||
To run this example, run the following commands simultaneously with
|
To run this example, run the following commands simultaneously with
|
||||||
@ -53,15 +53,10 @@ async def main():
|
|||||||
def per_engine_logger_factory(config: VllmConfig, rank: int) -> LoggingStatLogger:
|
def per_engine_logger_factory(config: VllmConfig, rank: int) -> LoggingStatLogger:
|
||||||
return LoggingStatLogger(config, rank)
|
return LoggingStatLogger(config, rank)
|
||||||
|
|
||||||
def global_logger_factory(
|
|
||||||
config: VllmConfig, engine_idxs: Optional[list[int]]
|
|
||||||
) -> GlobalStatLogger:
|
|
||||||
return GlobalStatLogger(config, engine_indexes=engine_idxs)
|
|
||||||
|
|
||||||
engine_client = AsyncLLMEngine.from_engine_args(
|
engine_client = AsyncLLMEngine.from_engine_args(
|
||||||
engine_args,
|
engine_args,
|
||||||
stat_loggers=[per_engine_logger_factory],
|
# Example: Using both regular loggers and aggregated logger
|
||||||
stat_logger_global=global_logger_factory,
|
stat_loggers=[per_engine_logger_factory, AggregatedStatLogger],
|
||||||
)
|
)
|
||||||
stop_logging_event = threading.Event()
|
stop_logging_event = threading.Event()
|
||||||
logging_thread = threading.Thread(
|
logging_thread = threading.Thread(
|
||||||
|
|||||||
@ -18,7 +18,7 @@ from vllm.platforms import current_platform
|
|||||||
from vllm.sampling_params import RequestOutputKind
|
from vllm.sampling_params import RequestOutputKind
|
||||||
from vllm.utils import set_default_torch_num_threads
|
from vllm.utils import set_default_torch_num_threads
|
||||||
from vllm.v1.engine.async_llm import AsyncLLM
|
from vllm.v1.engine.async_llm import AsyncLLM
|
||||||
from vllm.v1.metrics.loggers import GlobalStatLogger, LoggingStatLogger
|
from vllm.v1.metrics.loggers import AggregatedStatLogger, LoggingStatLogger
|
||||||
|
|
||||||
if not current_platform.is_cuda():
|
if not current_platform.is_cuda():
|
||||||
pytest.skip(reason="V1 currently only supported on CUDA.",
|
pytest.skip(reason="V1 currently only supported on CUDA.",
|
||||||
@ -389,7 +389,7 @@ class MockLoggingStatLogger(LoggingStatLogger):
|
|||||||
self.log = MagicMock()
|
self.log = MagicMock()
|
||||||
|
|
||||||
|
|
||||||
class MockGlobalStatLogger(GlobalStatLogger):
|
class MockAggregatedStatLogger(AggregatedStatLogger):
|
||||||
|
|
||||||
def __init__(self,
|
def __init__(self,
|
||||||
vllm_config: VllmConfig,
|
vllm_config: VllmConfig,
|
||||||
@ -425,8 +425,8 @@ async def test_customize_loggers(monkeypatch):
|
|||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_customize_global_loggers(monkeypatch):
|
async def test_customize_aggregated_loggers(monkeypatch):
|
||||||
"""Test that we can customize the loggers.
|
"""Test that we can customize the aggregated loggers.
|
||||||
If a customized logger is provided at the init, it should
|
If a customized logger is provided at the init, it should
|
||||||
be added to the default loggers.
|
be added to the default loggers.
|
||||||
"""
|
"""
|
||||||
@ -437,8 +437,7 @@ async def test_customize_global_loggers(monkeypatch):
|
|||||||
with set_default_torch_num_threads(1):
|
with set_default_torch_num_threads(1):
|
||||||
engine = AsyncLLM.from_engine_args(
|
engine = AsyncLLM.from_engine_args(
|
||||||
TEXT_ENGINE_ARGS,
|
TEXT_ENGINE_ARGS,
|
||||||
stat_loggers=[MockLoggingStatLogger],
|
stat_loggers=[MockLoggingStatLogger, MockAggregatedStatLogger],
|
||||||
stat_logger_global=MockGlobalStatLogger,
|
|
||||||
)
|
)
|
||||||
after.callback(engine.shutdown)
|
after.callback(engine.shutdown)
|
||||||
|
|
||||||
@ -448,9 +447,9 @@ async def test_customize_global_loggers(monkeypatch):
|
|||||||
assert len(stat_loggers) == 1
|
assert len(stat_loggers) == 1
|
||||||
assert len(
|
assert len(
|
||||||
stat_loggers[0]) == 2 # LoggingStatLogger + MockLoggingStatLogger
|
stat_loggers[0]) == 2 # LoggingStatLogger + MockLoggingStatLogger
|
||||||
global_logger = engine.logger_manager.global_logger
|
aggregated_loggers = engine.logger_manager.aggregated_loggers
|
||||||
assert global_logger is not None
|
assert len(aggregated_loggers) == 1
|
||||||
global_logger.log.assert_called_once()
|
aggregated_loggers[0].log.assert_called_once()
|
||||||
stat_loggers[0][0].log.assert_called_once()
|
stat_loggers[0][0].log.assert_called_once()
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@ -42,8 +42,7 @@ from vllm.v1.engine.output_processor import (OutputProcessor,
|
|||||||
from vllm.v1.engine.parallel_sampling import ParentRequest
|
from vllm.v1.engine.parallel_sampling import ParentRequest
|
||||||
from vllm.v1.engine.processor import Processor
|
from vllm.v1.engine.processor import Processor
|
||||||
from vllm.v1.executor.abstract import Executor
|
from vllm.v1.executor.abstract import Executor
|
||||||
from vllm.v1.metrics.loggers import (GlobalStatLoggerFactory,
|
from vllm.v1.metrics.loggers import StatLoggerFactory, StatLoggerManager
|
||||||
StatLoggerFactory, StatLoggerManager)
|
|
||||||
from vllm.v1.metrics.prometheus import shutdown_prometheus
|
from vllm.v1.metrics.prometheus import shutdown_prometheus
|
||||||
from vllm.v1.metrics.stats import IterationStats
|
from vllm.v1.metrics.stats import IterationStats
|
||||||
|
|
||||||
@ -63,7 +62,6 @@ class AsyncLLM(EngineClient):
|
|||||||
log_requests: bool = True,
|
log_requests: bool = True,
|
||||||
start_engine_loop: bool = True,
|
start_engine_loop: bool = True,
|
||||||
stat_loggers: Optional[list[StatLoggerFactory]] = None,
|
stat_loggers: Optional[list[StatLoggerFactory]] = None,
|
||||||
stat_logger_global: Optional[GlobalStatLoggerFactory] = None,
|
|
||||||
client_addresses: Optional[dict[str, str]] = None,
|
client_addresses: Optional[dict[str, str]] = None,
|
||||||
client_count: int = 1,
|
client_count: int = 1,
|
||||||
client_index: int = 0,
|
client_index: int = 0,
|
||||||
@ -103,11 +101,8 @@ class AsyncLLM(EngineClient):
|
|||||||
self.observability_config = vllm_config.observability_config
|
self.observability_config = vllm_config.observability_config
|
||||||
self.log_requests = log_requests
|
self.log_requests = log_requests
|
||||||
|
|
||||||
self.log_stats = log_stats or (stat_loggers
|
self.log_stats = log_stats or (stat_loggers is not None)
|
||||||
is not None) or (stat_logger_global
|
if not log_stats and stat_loggers is not None:
|
||||||
is not None)
|
|
||||||
if (not log_stats and stat_loggers is not None
|
|
||||||
or stat_logger_global is not None):
|
|
||||||
logger.info(
|
logger.info(
|
||||||
"AsyncLLM created with log_stats=False and non-empty custom "
|
"AsyncLLM created with log_stats=False and non-empty custom "
|
||||||
"logger list; enabling logging without default stat loggers")
|
"logger list; enabling logging without default stat loggers")
|
||||||
@ -152,7 +147,6 @@ class AsyncLLM(EngineClient):
|
|||||||
vllm_config=vllm_config,
|
vllm_config=vllm_config,
|
||||||
engine_idxs=self.engine_core.engine_ranks_managed,
|
engine_idxs=self.engine_core.engine_ranks_managed,
|
||||||
custom_stat_loggers=stat_loggers,
|
custom_stat_loggers=stat_loggers,
|
||||||
custom_stat_logger_global=stat_logger_global,
|
|
||||||
enable_default_loggers=log_stats,
|
enable_default_loggers=log_stats,
|
||||||
client_count=client_count,
|
client_count=client_count,
|
||||||
)
|
)
|
||||||
@ -195,7 +189,6 @@ class AsyncLLM(EngineClient):
|
|||||||
start_engine_loop: bool = True,
|
start_engine_loop: bool = True,
|
||||||
usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
|
usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
|
||||||
stat_loggers: Optional[list[StatLoggerFactory]] = None,
|
stat_loggers: Optional[list[StatLoggerFactory]] = None,
|
||||||
stat_logger_global: Optional[GlobalStatLoggerFactory] = None,
|
|
||||||
enable_log_requests: bool = False,
|
enable_log_requests: bool = False,
|
||||||
disable_log_stats: bool = False,
|
disable_log_stats: bool = False,
|
||||||
client_addresses: Optional[dict[str, str]] = None,
|
client_addresses: Optional[dict[str, str]] = None,
|
||||||
@ -216,7 +209,6 @@ class AsyncLLM(EngineClient):
|
|||||||
executor_class=Executor.get_class(vllm_config),
|
executor_class=Executor.get_class(vllm_config),
|
||||||
start_engine_loop=start_engine_loop,
|
start_engine_loop=start_engine_loop,
|
||||||
stat_loggers=stat_loggers,
|
stat_loggers=stat_loggers,
|
||||||
stat_logger_global=stat_logger_global,
|
|
||||||
log_requests=enable_log_requests,
|
log_requests=enable_log_requests,
|
||||||
log_stats=not disable_log_stats,
|
log_stats=not disable_log_stats,
|
||||||
usage_context=usage_context,
|
usage_context=usage_context,
|
||||||
@ -232,7 +224,6 @@ class AsyncLLM(EngineClient):
|
|||||||
start_engine_loop: bool = True,
|
start_engine_loop: bool = True,
|
||||||
usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
|
usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
|
||||||
stat_loggers: Optional[list[StatLoggerFactory]] = None,
|
stat_loggers: Optional[list[StatLoggerFactory]] = None,
|
||||||
stat_logger_global: Optional[GlobalStatLoggerFactory] = None,
|
|
||||||
) -> "AsyncLLM":
|
) -> "AsyncLLM":
|
||||||
"""Create an AsyncLLM from the EngineArgs."""
|
"""Create an AsyncLLM from the EngineArgs."""
|
||||||
|
|
||||||
@ -249,7 +240,6 @@ class AsyncLLM(EngineClient):
|
|||||||
start_engine_loop=start_engine_loop,
|
start_engine_loop=start_engine_loop,
|
||||||
usage_context=usage_context,
|
usage_context=usage_context,
|
||||||
stat_loggers=stat_loggers,
|
stat_loggers=stat_loggers,
|
||||||
stat_logger_global=stat_logger_global,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
def __del__(self):
|
def __del__(self):
|
||||||
|
|||||||
@ -18,9 +18,9 @@ from vllm.v1.spec_decode.metrics import SpecDecodingLogging, SpecDecodingProm
|
|||||||
|
|
||||||
logger = init_logger(__name__)
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
StatLoggerFactory = Callable[[VllmConfig, int], "StatLoggerBase"]
|
PerEngineStatLoggerFactory = Callable[[VllmConfig, int], "StatLoggerBase"]
|
||||||
GlobalStatLoggerFactory = Callable[[VllmConfig, list[int]],
|
StatLoggerFactory = Union[PerEngineStatLoggerFactory,
|
||||||
"GlobalStatLoggerBase"]
|
type["AggregatedStatLoggerBase"]]
|
||||||
|
|
||||||
|
|
||||||
class StatLoggerBase(ABC):
|
class StatLoggerBase(ABC):
|
||||||
@ -50,8 +50,9 @@ class StatLoggerBase(ABC):
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
class GlobalStatLoggerBase(StatLoggerBase):
|
class AggregatedStatLoggerBase(StatLoggerBase):
|
||||||
"""Interface for logging metrics for multiple engines."""
|
"""Abstract base class for loggers that
|
||||||
|
aggregates statistics across multiple engines."""
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def __init__(self, vllm_config: VllmConfig,
|
def __init__(self, vllm_config: VllmConfig,
|
||||||
@ -72,6 +73,7 @@ class LoggingStatLogger(StatLoggerBase):
|
|||||||
self.spec_decoding_logging = SpecDecodingLogging()
|
self.spec_decoding_logging = SpecDecodingLogging()
|
||||||
self.last_prompt_throughput: float = 0.0
|
self.last_prompt_throughput: float = 0.0
|
||||||
self.last_generation_throughput: float = 0.0
|
self.last_generation_throughput: float = 0.0
|
||||||
|
self.engine_is_idle = False
|
||||||
|
|
||||||
def _reset(self, now):
|
def _reset(self, now):
|
||||||
self.last_log_time = now
|
self.last_log_time = now
|
||||||
@ -111,25 +113,25 @@ class LoggingStatLogger(StatLoggerBase):
|
|||||||
|
|
||||||
self.last_scheduler_stats = scheduler_stats
|
self.last_scheduler_stats = scheduler_stats
|
||||||
|
|
||||||
def log(self):
|
def get_log_stats(self):
|
||||||
now = time.monotonic()
|
now = time.monotonic()
|
||||||
prompt_throughput = self._get_throughput(self.num_prompt_tokens, now)
|
prompt_throughput = self._get_throughput(self.num_prompt_tokens, now)
|
||||||
generation_throughput = self._get_throughput(
|
generation_throughput = self._get_throughput(
|
||||||
self.num_generation_tokens, now)
|
self.num_generation_tokens, now)
|
||||||
|
|
||||||
self._reset(now)
|
self._reset(now)
|
||||||
|
|
||||||
scheduler_stats = self.last_scheduler_stats
|
|
||||||
|
|
||||||
log_fn = logger.info
|
|
||||||
if not any(
|
|
||||||
(prompt_throughput, generation_throughput,
|
|
||||||
self.last_prompt_throughput, self.last_generation_throughput)):
|
|
||||||
# Avoid log noise on an idle production system
|
|
||||||
log_fn = logger.debug
|
|
||||||
self.last_generation_throughput = generation_throughput
|
self.last_generation_throughput = generation_throughput
|
||||||
self.last_prompt_throughput = prompt_throughput
|
self.last_prompt_throughput = prompt_throughput
|
||||||
|
self.engine_is_idle = not any(
|
||||||
|
(prompt_throughput, generation_throughput,
|
||||||
|
self.last_prompt_throughput, self.last_generation_throughput))
|
||||||
|
|
||||||
|
def log(self):
|
||||||
|
self.get_log_stats()
|
||||||
|
log_fn = logger.info
|
||||||
|
if self.engine_is_idle:
|
||||||
|
# Avoid log noise on an idle production system
|
||||||
|
log_fn = logger.debug
|
||||||
# Format and print output.
|
# Format and print output.
|
||||||
log_fn(
|
log_fn(
|
||||||
"Engine %03d: "
|
"Engine %03d: "
|
||||||
@ -139,11 +141,11 @@ class LoggingStatLogger(StatLoggerBase):
|
|||||||
"GPU KV cache usage: %.1f%%, "
|
"GPU KV cache usage: %.1f%%, "
|
||||||
"Prefix cache hit rate: %.1f%%",
|
"Prefix cache hit rate: %.1f%%",
|
||||||
self.engine_index,
|
self.engine_index,
|
||||||
prompt_throughput,
|
self.last_prompt_throughput,
|
||||||
generation_throughput,
|
self.last_generation_throughput,
|
||||||
scheduler_stats.num_running_reqs,
|
self.last_scheduler_stats.num_running_reqs,
|
||||||
scheduler_stats.num_waiting_reqs,
|
self.last_scheduler_stats.num_waiting_reqs,
|
||||||
scheduler_stats.kv_cache_usage * 100,
|
self.last_scheduler_stats.kv_cache_usage * 100,
|
||||||
self.prefix_caching_metrics.hit_rate * 100,
|
self.prefix_caching_metrics.hit_rate * 100,
|
||||||
)
|
)
|
||||||
self.spec_decoding_logging.log(log_fn=log_fn)
|
self.spec_decoding_logging.log(log_fn=log_fn)
|
||||||
@ -156,37 +158,34 @@ class LoggingStatLogger(StatLoggerBase):
|
|||||||
self.vllm_config.cache_config.num_gpu_blocks)
|
self.vllm_config.cache_config.num_gpu_blocks)
|
||||||
|
|
||||||
|
|
||||||
class GlobalStatLogger(LoggingStatLogger, GlobalStatLoggerBase):
|
class AggregatedStatLogger(LoggingStatLogger, AggregatedStatLoggerBase):
|
||||||
|
|
||||||
def __init__(self,
|
def __init__(self,
|
||||||
vllm_config: VllmConfig,
|
vllm_config: VllmConfig,
|
||||||
engine_indexes: Optional[list[int]] = None):
|
engine_idxs: Optional[list[int]] = None):
|
||||||
super().__init__(vllm_config, -1)
|
if engine_idxs is None:
|
||||||
if engine_indexes is None:
|
engine_idxs = [0]
|
||||||
engine_indexes = [0]
|
self.engine_idxs = engine_idxs
|
||||||
self.engine_index = -1
|
LoggingStatLogger.__init__(self, vllm_config, engine_index=-1)
|
||||||
self.engine_indexes = engine_indexes
|
|
||||||
self.vllm_config = vllm_config
|
def record(
|
||||||
|
self,
|
||||||
|
scheduler_stats: Optional[SchedulerStats],
|
||||||
|
iteration_stats: Optional[IterationStats],
|
||||||
|
engine_idx: int = 0,
|
||||||
|
):
|
||||||
|
if engine_idx not in self.engine_idxs:
|
||||||
|
logger.warning("Unexpected engine_idx: %d", engine_idx)
|
||||||
|
return
|
||||||
|
LoggingStatLogger.record(self, scheduler_stats, iteration_stats,
|
||||||
|
engine_idx)
|
||||||
|
|
||||||
def log(self):
|
def log(self):
|
||||||
now = time.monotonic()
|
self.get_log_stats()
|
||||||
prompt_throughput = self._get_throughput(self.num_prompt_tokens, now)
|
|
||||||
generation_throughput = self._get_throughput(
|
|
||||||
self.num_generation_tokens, now)
|
|
||||||
|
|
||||||
self._reset(now)
|
|
||||||
|
|
||||||
scheduler_stats = self.last_scheduler_stats
|
|
||||||
|
|
||||||
log_fn = logger.info
|
log_fn = logger.info
|
||||||
if not any(
|
if self.engine_is_idle:
|
||||||
(prompt_throughput, generation_throughput,
|
|
||||||
self.last_prompt_throughput, self.last_generation_throughput)):
|
|
||||||
# Avoid log noise on an idle production system
|
# Avoid log noise on an idle production system
|
||||||
log_fn = logger.debug
|
log_fn = logger.debug
|
||||||
self.last_generation_throughput = generation_throughput
|
|
||||||
self.last_prompt_throughput = prompt_throughput
|
|
||||||
|
|
||||||
# Format and print output.
|
# Format and print output.
|
||||||
log_fn(
|
log_fn(
|
||||||
"%s Engines Aggregated: "
|
"%s Engines Aggregated: "
|
||||||
@ -195,12 +194,12 @@ class GlobalStatLogger(LoggingStatLogger, GlobalStatLoggerBase):
|
|||||||
"Running: %d reqs, Waiting: %d reqs, "
|
"Running: %d reqs, Waiting: %d reqs, "
|
||||||
"GPU KV cache usage: %.1f%%, "
|
"GPU KV cache usage: %.1f%%, "
|
||||||
"Prefix cache hit rate: %.1f%%",
|
"Prefix cache hit rate: %.1f%%",
|
||||||
len(self.engine_indexes),
|
len(self.engine_idxs),
|
||||||
prompt_throughput,
|
self.last_prompt_throughput,
|
||||||
generation_throughput,
|
self.last_generation_throughput,
|
||||||
scheduler_stats.num_running_reqs,
|
self.last_scheduler_stats.num_running_reqs,
|
||||||
scheduler_stats.num_waiting_reqs,
|
self.last_scheduler_stats.num_waiting_reqs,
|
||||||
scheduler_stats.kv_cache_usage * 100,
|
self.last_scheduler_stats.kv_cache_usage * 100,
|
||||||
self.prefix_caching_metrics.hit_rate * 100,
|
self.prefix_caching_metrics.hit_rate * 100,
|
||||||
)
|
)
|
||||||
self.spec_decoding_logging.log(log_fn=log_fn)
|
self.spec_decoding_logging.log(log_fn=log_fn)
|
||||||
@ -209,11 +208,11 @@ class GlobalStatLogger(LoggingStatLogger, GlobalStatLoggerBase):
|
|||||||
if self.vllm_config.cache_config.num_gpu_blocks:
|
if self.vllm_config.cache_config.num_gpu_blocks:
|
||||||
logger.info(
|
logger.info(
|
||||||
"%d Engines: vllm cache_config_info with initialization "
|
"%d Engines: vllm cache_config_info with initialization "
|
||||||
"after num_gpu_blocks is: %d", len(self.engine_indexes),
|
"after num_gpu_blocks is: %d", len(self.engine_idxs),
|
||||||
self.vllm_config.cache_config.num_gpu_blocks)
|
self.vllm_config.cache_config.num_gpu_blocks)
|
||||||
|
|
||||||
|
|
||||||
class PrometheusStatLogger(StatLoggerBase):
|
class PrometheusStatLogger(AggregatedStatLoggerBase):
|
||||||
_gauge_cls = prometheus_client.Gauge
|
_gauge_cls = prometheus_client.Gauge
|
||||||
_counter_cls = prometheus_client.Counter
|
_counter_cls = prometheus_client.Counter
|
||||||
_histogram_cls = prometheus_client.Histogram
|
_histogram_cls = prometheus_client.Histogram
|
||||||
@ -723,7 +722,6 @@ class StatLoggerManager:
|
|||||||
vllm_config: VllmConfig,
|
vllm_config: VllmConfig,
|
||||||
engine_idxs: Optional[list[int]] = None,
|
engine_idxs: Optional[list[int]] = None,
|
||||||
custom_stat_loggers: Optional[list[StatLoggerFactory]] = None,
|
custom_stat_loggers: Optional[list[StatLoggerFactory]] = None,
|
||||||
custom_stat_logger_global: Optional[GlobalStatLoggerFactory] = None,
|
|
||||||
enable_default_loggers: bool = True,
|
enable_default_loggers: bool = True,
|
||||||
client_count: int = 1,
|
client_count: int = 1,
|
||||||
):
|
):
|
||||||
@ -743,28 +741,32 @@ class StatLoggerManager:
|
|||||||
|
|
||||||
# engine_idx: StatLogger
|
# engine_idx: StatLogger
|
||||||
self.per_engine_logger_dict: dict[int, list[StatLoggerBase]] = {}
|
self.per_engine_logger_dict: dict[int, list[StatLoggerBase]] = {}
|
||||||
prometheus_factory = PrometheusStatLogger
|
self.aggregated_loggers: list[AggregatedStatLoggerBase] = []
|
||||||
|
|
||||||
|
aggregated_loggers_factories = set()
|
||||||
for engine_idx in self.engine_idxs:
|
for engine_idx in self.engine_idxs:
|
||||||
loggers: list[StatLoggerBase] = []
|
loggers: list[StatLoggerBase] = []
|
||||||
for logger_factory in factories:
|
for logger_factory in factories:
|
||||||
# If we get a custom prometheus logger, use that
|
# If we get a custom prometheus logger or aggregated logger,
|
||||||
# instead. This is typically used for the ray case.
|
# We initialize it separately with all engine idxs.
|
||||||
if (isinstance(logger_factory, type)
|
# A custom prometheus logger is typically used for the ray.
|
||||||
and issubclass(logger_factory, PrometheusStatLogger)):
|
if (isinstance(logger_factory, type) and issubclass(
|
||||||
prometheus_factory = logger_factory
|
logger_factory, AggregatedStatLoggerBase)):
|
||||||
continue
|
aggregated_loggers_factories.add(logger_factory)
|
||||||
loggers.append(logger_factory(vllm_config,
|
else:
|
||||||
engine_idx)) # type: ignore
|
loggers.append(logger_factory(vllm_config,
|
||||||
|
engine_idx)) # type: ignore
|
||||||
self.per_engine_logger_dict[engine_idx] = loggers
|
self.per_engine_logger_dict[engine_idx] = loggers
|
||||||
|
# If no custom aggregated logger is provide,
|
||||||
# For Prometheus or custom global logger,
|
# we by default use PrometheusStatLogger
|
||||||
|
if not aggregated_loggers_factories:
|
||||||
|
aggregated_loggers_factories.add(PrometheusStatLogger)
|
||||||
|
# For custom aggregated logger(or default Prometheus Logger)
|
||||||
# need to share the metrics between EngineCores.
|
# need to share the metrics between EngineCores.
|
||||||
# Each EngineCore's metrics are expressed as a unique label.
|
# Each EngineCore's metrics are expressed as a unique label.
|
||||||
self.prometheus_logger = prometheus_factory(vllm_config, engine_idxs)
|
for aggregated_loggers_factory in aggregated_loggers_factories:
|
||||||
self.global_logger: Optional[StatLoggerBase] = None
|
self.aggregated_loggers.append(
|
||||||
if custom_stat_logger_global is not None:
|
aggregated_loggers_factory(vllm_config, engine_idxs))
|
||||||
self.global_logger = custom_stat_logger_global(
|
|
||||||
vllm_config, self.engine_idxs)
|
|
||||||
|
|
||||||
def record(
|
def record(
|
||||||
self,
|
self,
|
||||||
@ -778,24 +780,19 @@ class StatLoggerManager:
|
|||||||
per_engine_loggers = self.per_engine_logger_dict[engine_idx]
|
per_engine_loggers = self.per_engine_logger_dict[engine_idx]
|
||||||
for logger in per_engine_loggers:
|
for logger in per_engine_loggers:
|
||||||
logger.record(scheduler_stats, iteration_stats, engine_idx)
|
logger.record(scheduler_stats, iteration_stats, engine_idx)
|
||||||
|
for logger in self.aggregated_loggers:
|
||||||
self.prometheus_logger.record(scheduler_stats, iteration_stats,
|
logger.record(scheduler_stats, iteration_stats, engine_idx)
|
||||||
engine_idx)
|
|
||||||
if self.global_logger is not None:
|
|
||||||
self.global_logger.record(scheduler_stats, iteration_stats,
|
|
||||||
engine_idx)
|
|
||||||
|
|
||||||
def log(self):
|
def log(self):
|
||||||
for per_engine_loggers in self.per_engine_logger_dict.values():
|
for per_engine_loggers in self.per_engine_logger_dict.values():
|
||||||
for logger in per_engine_loggers:
|
for logger in per_engine_loggers:
|
||||||
logger.log()
|
logger.log()
|
||||||
if self.global_logger is not None:
|
for logger in self.aggregated_loggers:
|
||||||
self.global_logger.log()
|
logger.log()
|
||||||
|
|
||||||
def log_engine_initialized(self):
|
def log_engine_initialized(self):
|
||||||
self.prometheus_logger.log_engine_initialized()
|
for agg_logger in self.aggregated_loggers:
|
||||||
if self.global_logger is not None:
|
agg_logger.log_engine_initialized()
|
||||||
self.global_logger.log_engine_initialized()
|
|
||||||
for per_engine_loggers in self.per_engine_logger_dict.values():
|
for per_engine_loggers in self.per_engine_logger_dict.values():
|
||||||
for logger in per_engine_loggers:
|
for logger in per_engine_loggers:
|
||||||
logger.log_engine_initialized()
|
logger.log_engine_initialized()
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user