mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-03-16 14:07:13 +08:00
[Misc][DP] support customized aggregated logger for dp (#24354)
Signed-off-by: Lu Fang <fanglu@fb.com>
This commit is contained in:
parent
d8bebb008a
commit
8317f72354
@ -1,11 +1,13 @@
|
|||||||
# SPDX-License-Identifier: Apache-2.0
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
import asyncio
|
import asyncio
|
||||||
|
import threading
|
||||||
|
|
||||||
from vllm.engine.arg_utils import AsyncEngineArgs
|
from vllm.engine.arg_utils import AsyncEngineArgs
|
||||||
from vllm.engine.async_llm_engine import AsyncLLMEngine
|
from vllm.engine.async_llm_engine import AsyncLLMEngine
|
||||||
from vllm.outputs import RequestOutput
|
from vllm.outputs import RequestOutput
|
||||||
from vllm.sampling_params import SamplingParams
|
from vllm.sampling_params import SamplingParams
|
||||||
|
from vllm.v1.metrics.loggers import AggregatedLoggingStatLogger
|
||||||
|
|
||||||
"""
|
"""
|
||||||
To run this example, run the following commands simultaneously with
|
To run this example, run the following commands simultaneously with
|
||||||
@ -21,37 +23,64 @@ send a request to the instance with DP rank 1.
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
def _do_background_logging(engine, interval, stop_event):
|
||||||
|
try:
|
||||||
|
while not stop_event.is_set():
|
||||||
|
asyncio.run(engine.do_log_stats())
|
||||||
|
stop_event.wait(interval)
|
||||||
|
except Exception as e:
|
||||||
|
print(f"vLLM background logging shutdown: {e}")
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
async def main():
|
async def main():
|
||||||
engine_args = AsyncEngineArgs(
|
engine_args = AsyncEngineArgs(
|
||||||
model="ibm-research/PowerMoE-3b",
|
model="ibm-research/PowerMoE-3b",
|
||||||
data_parallel_size=2,
|
data_parallel_size=2,
|
||||||
|
tensor_parallel_size=1,
|
||||||
dtype="auto",
|
dtype="auto",
|
||||||
max_model_len=2048,
|
max_model_len=2048,
|
||||||
data_parallel_address="127.0.0.1",
|
data_parallel_address="127.0.0.1",
|
||||||
data_parallel_rpc_port=62300,
|
data_parallel_rpc_port=62300,
|
||||||
data_parallel_size_local=1,
|
data_parallel_size_local=1,
|
||||||
enforce_eager=True,
|
enforce_eager=True,
|
||||||
|
enable_log_requests=True,
|
||||||
|
disable_custom_all_reduce=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
engine_client = AsyncLLMEngine.from_engine_args(engine_args)
|
engine_client = AsyncLLMEngine.from_engine_args(
|
||||||
|
engine_args,
|
||||||
|
# Example: Using aggregated logger
|
||||||
|
stat_loggers=[AggregatedLoggingStatLogger],
|
||||||
|
)
|
||||||
|
stop_logging_event = threading.Event()
|
||||||
|
logging_thread = threading.Thread(
|
||||||
|
target=_do_background_logging,
|
||||||
|
args=(engine_client, 5, stop_logging_event),
|
||||||
|
daemon=True,
|
||||||
|
)
|
||||||
|
logging_thread.start()
|
||||||
sampling_params = SamplingParams(
|
sampling_params = SamplingParams(
|
||||||
temperature=0.7,
|
temperature=0.7,
|
||||||
top_p=0.9,
|
top_p=0.9,
|
||||||
max_tokens=100,
|
max_tokens=100,
|
||||||
)
|
)
|
||||||
|
num_prompts = 10
|
||||||
|
for i in range(num_prompts):
|
||||||
|
prompt = "Who won the 2004 World Series?"
|
||||||
|
final_output: RequestOutput | None = None
|
||||||
|
async for output in engine_client.generate(
|
||||||
|
prompt=prompt,
|
||||||
|
sampling_params=sampling_params,
|
||||||
|
request_id=f"abcdef-{i}",
|
||||||
|
data_parallel_rank=1,
|
||||||
|
):
|
||||||
|
final_output = output
|
||||||
|
if final_output:
|
||||||
|
print(final_output.outputs[0].text)
|
||||||
|
|
||||||
prompt = "Who won the 2004 World Series?"
|
stop_logging_event.set()
|
||||||
final_output: RequestOutput | None = None
|
logging_thread.join()
|
||||||
async for output in engine_client.generate(
|
|
||||||
prompt=prompt,
|
|
||||||
sampling_params=sampling_params,
|
|
||||||
request_id="abcdef",
|
|
||||||
data_parallel_rank=1,
|
|
||||||
):
|
|
||||||
final_output = output
|
|
||||||
if final_output:
|
|
||||||
print(final_output.outputs[0].text)
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|||||||
@ -17,7 +17,12 @@ from vllm.platforms import current_platform
|
|||||||
from vllm.sampling_params import RequestOutputKind
|
from vllm.sampling_params import RequestOutputKind
|
||||||
from vllm.utils import set_default_torch_num_threads
|
from vllm.utils import set_default_torch_num_threads
|
||||||
from vllm.v1.engine.async_llm import AsyncLLM
|
from vllm.v1.engine.async_llm import AsyncLLM
|
||||||
from vllm.v1.metrics.loggers import LoggingStatLogger
|
from vllm.v1.metrics.loggers import (
|
||||||
|
AggregatedLoggingStatLogger,
|
||||||
|
LoggingStatLogger,
|
||||||
|
PerEngineStatLoggerAdapter,
|
||||||
|
PrometheusStatLogger,
|
||||||
|
)
|
||||||
|
|
||||||
if not current_platform.is_cuda():
|
if not current_platform.is_cuda():
|
||||||
pytest.skip(reason="V1 currently only supported on CUDA.", allow_module_level=True)
|
pytest.skip(reason="V1 currently only supported on CUDA.", allow_module_level=True)
|
||||||
@ -384,6 +389,12 @@ class MockLoggingStatLogger(LoggingStatLogger):
|
|||||||
self.log = MagicMock()
|
self.log = MagicMock()
|
||||||
|
|
||||||
|
|
||||||
|
class MockAggregatedStatLogger(AggregatedLoggingStatLogger):
|
||||||
|
def __init__(self, vllm_config: VllmConfig, engine_indexes: list[int]):
|
||||||
|
super().__init__(vllm_config, engine_indexes)
|
||||||
|
self.log = MagicMock()
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_customize_loggers(monkeypatch):
|
async def test_customize_loggers(monkeypatch):
|
||||||
"""Test that we can customize the loggers.
|
"""Test that we can customize the loggers.
|
||||||
@ -401,10 +412,45 @@ async def test_customize_loggers(monkeypatch):
|
|||||||
|
|
||||||
await engine.do_log_stats()
|
await engine.do_log_stats()
|
||||||
|
|
||||||
stat_loggers = engine.logger_manager.per_engine_logger_dict
|
stat_loggers = engine.logger_manager.stat_loggers
|
||||||
assert len(stat_loggers) == 1
|
assert (
|
||||||
assert len(stat_loggers[0]) == 2 # LoggingStatLogger + MockLoggingStatLogger
|
len(stat_loggers) == 3
|
||||||
stat_loggers[0][0].log.assert_called_once()
|
) # MockLoggingStatLogger + LoggingStatLogger + Promethus Logger
|
||||||
|
print(f"{stat_loggers=}")
|
||||||
|
stat_loggers[0].per_engine_stat_loggers[0].log.assert_called_once()
|
||||||
|
assert isinstance(stat_loggers[1], PerEngineStatLoggerAdapter)
|
||||||
|
assert isinstance(stat_loggers[1].per_engine_stat_loggers[0], LoggingStatLogger)
|
||||||
|
assert isinstance(stat_loggers[2], PrometheusStatLogger)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_customize_aggregated_loggers(monkeypatch):
|
||||||
|
"""Test that we can customize the aggregated loggers.
|
||||||
|
If a customized logger is provided at the init, it should
|
||||||
|
be added to the default loggers.
|
||||||
|
"""
|
||||||
|
|
||||||
|
with monkeypatch.context() as m, ExitStack() as after:
|
||||||
|
m.setenv("VLLM_USE_V1", "1")
|
||||||
|
|
||||||
|
with set_default_torch_num_threads(1):
|
||||||
|
engine = AsyncLLM.from_engine_args(
|
||||||
|
TEXT_ENGINE_ARGS,
|
||||||
|
stat_loggers=[MockLoggingStatLogger, MockAggregatedStatLogger],
|
||||||
|
)
|
||||||
|
after.callback(engine.shutdown)
|
||||||
|
|
||||||
|
await engine.do_log_stats()
|
||||||
|
|
||||||
|
stat_loggers = engine.logger_manager.stat_loggers
|
||||||
|
assert len(stat_loggers) == 4
|
||||||
|
# MockLoggingStatLogger + MockAggregatedStatLogger
|
||||||
|
# + LoggingStatLogger + PrometheusStatLogger
|
||||||
|
stat_loggers[0].per_engine_stat_loggers[0].log.assert_called_once()
|
||||||
|
stat_loggers[1].log.assert_called_once()
|
||||||
|
assert isinstance(stat_loggers[2], PerEngineStatLoggerAdapter)
|
||||||
|
assert isinstance(stat_loggers[2].per_engine_stat_loggers[0], LoggingStatLogger)
|
||||||
|
assert isinstance(stat_loggers[3], PrometheusStatLogger)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio(scope="module")
|
@pytest.mark.asyncio(scope="module")
|
||||||
|
|||||||
@ -54,7 +54,7 @@ async def test_async_llm_replace_default_loggers(log_stats_enabled_engine_args):
|
|||||||
engine = AsyncLLM.from_engine_args(
|
engine = AsyncLLM.from_engine_args(
|
||||||
log_stats_enabled_engine_args, stat_loggers=[RayPrometheusStatLogger]
|
log_stats_enabled_engine_args, stat_loggers=[RayPrometheusStatLogger]
|
||||||
)
|
)
|
||||||
assert isinstance(engine.logger_manager.prometheus_logger, RayPrometheusStatLogger)
|
assert isinstance(engine.logger_manager.stat_loggers[0], RayPrometheusStatLogger)
|
||||||
engine.shutdown()
|
engine.shutdown()
|
||||||
|
|
||||||
|
|
||||||
@ -73,9 +73,11 @@ async def test_async_llm_add_to_default_loggers(log_stats_enabled_engine_args):
|
|||||||
disabled_log_engine_args, stat_loggers=[DummyStatLogger]
|
disabled_log_engine_args, stat_loggers=[DummyStatLogger]
|
||||||
)
|
)
|
||||||
|
|
||||||
assert len(engine.logger_manager.per_engine_logger_dict[0]) == 1
|
assert len(engine.logger_manager.stat_loggers) == 2
|
||||||
|
assert len(engine.logger_manager.stat_loggers[0].per_engine_stat_loggers) == 1
|
||||||
assert isinstance(
|
assert isinstance(
|
||||||
engine.logger_manager.per_engine_logger_dict[0][0], DummyStatLogger
|
engine.logger_manager.stat_loggers[0].per_engine_stat_loggers[0],
|
||||||
|
DummyStatLogger,
|
||||||
)
|
)
|
||||||
|
|
||||||
# log_stats is still True, since custom stat loggers are used
|
# log_stats is still True, since custom stat loggers are used
|
||||||
|
|||||||
@ -410,6 +410,7 @@ class EngineArgs:
|
|||||||
max_logprobs: int = ModelConfig.max_logprobs
|
max_logprobs: int = ModelConfig.max_logprobs
|
||||||
logprobs_mode: LogprobsMode = ModelConfig.logprobs_mode
|
logprobs_mode: LogprobsMode = ModelConfig.logprobs_mode
|
||||||
disable_log_stats: bool = False
|
disable_log_stats: bool = False
|
||||||
|
aggregate_engine_logging: bool = False
|
||||||
revision: str | None = ModelConfig.revision
|
revision: str | None = ModelConfig.revision
|
||||||
code_revision: str | None = ModelConfig.code_revision
|
code_revision: str | None = ModelConfig.code_revision
|
||||||
rope_scaling: dict[str, Any] = get_field(ModelConfig, "rope_scaling")
|
rope_scaling: dict[str, Any] = get_field(ModelConfig, "rope_scaling")
|
||||||
@ -1043,6 +1044,12 @@ class EngineArgs:
|
|||||||
help="Disable logging statistics.",
|
help="Disable logging statistics.",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--aggregate-engine-logging",
|
||||||
|
action="store_true",
|
||||||
|
help="Log aggregate rather than per-engine statistics "
|
||||||
|
"when using data parallelism.",
|
||||||
|
)
|
||||||
return parser
|
return parser
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
|
|||||||
@ -239,6 +239,7 @@ async def build_async_engine_client_from_engine_args(
|
|||||||
vllm_config=vllm_config,
|
vllm_config=vllm_config,
|
||||||
usage_context=usage_context,
|
usage_context=usage_context,
|
||||||
enable_log_requests=engine_args.enable_log_requests,
|
enable_log_requests=engine_args.enable_log_requests,
|
||||||
|
aggregate_engine_logging=engine_args.aggregate_engine_logging,
|
||||||
disable_log_stats=engine_args.disable_log_stats,
|
disable_log_stats=engine_args.disable_log_stats,
|
||||||
client_addresses=client_config,
|
client_addresses=client_config,
|
||||||
client_count=client_count,
|
client_count=client_count,
|
||||||
|
|||||||
@ -57,6 +57,7 @@ class AsyncLLM(EngineClient):
|
|||||||
log_requests: bool = True,
|
log_requests: bool = True,
|
||||||
start_engine_loop: bool = True,
|
start_engine_loop: bool = True,
|
||||||
stat_loggers: list[StatLoggerFactory] | None = None,
|
stat_loggers: list[StatLoggerFactory] | None = None,
|
||||||
|
aggregate_engine_logging: bool = False,
|
||||||
client_addresses: dict[str, str] | None = None,
|
client_addresses: dict[str, str] | None = None,
|
||||||
client_count: int = 1,
|
client_count: int = 1,
|
||||||
client_index: int = 0,
|
client_index: int = 0,
|
||||||
@ -144,6 +145,7 @@ class AsyncLLM(EngineClient):
|
|||||||
custom_stat_loggers=stat_loggers,
|
custom_stat_loggers=stat_loggers,
|
||||||
enable_default_loggers=log_stats,
|
enable_default_loggers=log_stats,
|
||||||
client_count=client_count,
|
client_count=client_count,
|
||||||
|
aggregate_engine_logging=aggregate_engine_logging,
|
||||||
)
|
)
|
||||||
self.logger_manager.log_engine_initialized()
|
self.logger_manager.log_engine_initialized()
|
||||||
|
|
||||||
@ -187,6 +189,7 @@ class AsyncLLM(EngineClient):
|
|||||||
usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
|
usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
|
||||||
stat_loggers: list[StatLoggerFactory] | None = None,
|
stat_loggers: list[StatLoggerFactory] | None = None,
|
||||||
enable_log_requests: bool = False,
|
enable_log_requests: bool = False,
|
||||||
|
aggregate_engine_logging: bool = False,
|
||||||
disable_log_stats: bool = False,
|
disable_log_stats: bool = False,
|
||||||
client_addresses: dict[str, str] | None = None,
|
client_addresses: dict[str, str] | None = None,
|
||||||
client_count: int = 1,
|
client_count: int = 1,
|
||||||
@ -209,6 +212,7 @@ class AsyncLLM(EngineClient):
|
|||||||
stat_loggers=stat_loggers,
|
stat_loggers=stat_loggers,
|
||||||
log_requests=enable_log_requests,
|
log_requests=enable_log_requests,
|
||||||
log_stats=not disable_log_stats,
|
log_stats=not disable_log_stats,
|
||||||
|
aggregate_engine_logging=aggregate_engine_logging,
|
||||||
usage_context=usage_context,
|
usage_context=usage_context,
|
||||||
client_addresses=client_addresses,
|
client_addresses=client_addresses,
|
||||||
client_count=client_count,
|
client_count=client_count,
|
||||||
|
|||||||
@ -51,6 +51,7 @@ class LLMEngine:
|
|||||||
vllm_config: VllmConfig,
|
vllm_config: VllmConfig,
|
||||||
executor_class: type[Executor],
|
executor_class: type[Executor],
|
||||||
log_stats: bool,
|
log_stats: bool,
|
||||||
|
aggregate_engine_logging: bool = False,
|
||||||
usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
|
usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
|
||||||
stat_loggers: list[StatLoggerFactory] | None = None,
|
stat_loggers: list[StatLoggerFactory] | None = None,
|
||||||
mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY,
|
mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY,
|
||||||
@ -132,6 +133,7 @@ class LLMEngine:
|
|||||||
vllm_config=vllm_config,
|
vllm_config=vllm_config,
|
||||||
custom_stat_loggers=stat_loggers,
|
custom_stat_loggers=stat_loggers,
|
||||||
enable_default_loggers=log_stats,
|
enable_default_loggers=log_stats,
|
||||||
|
aggregate_engine_logging=aggregate_engine_logging,
|
||||||
)
|
)
|
||||||
self.logger_manager.log_engine_initialized()
|
self.logger_manager.log_engine_initialized()
|
||||||
|
|
||||||
|
|||||||
@ -24,7 +24,9 @@ from vllm.v1.spec_decode.metrics import SpecDecodingLogging, SpecDecodingProm
|
|||||||
|
|
||||||
logger = init_logger(__name__)
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
StatLoggerFactory = Callable[[VllmConfig, int], "StatLoggerBase"]
|
PerEngineStatLoggerFactory = Callable[[VllmConfig, int], "StatLoggerBase"]
|
||||||
|
AggregateStatLoggerFactory = type["AggregateStatLoggerBase"]
|
||||||
|
StatLoggerFactory = AggregateStatLoggerFactory | PerEngineStatLoggerFactory
|
||||||
|
|
||||||
|
|
||||||
class StatLoggerBase(ABC):
|
class StatLoggerBase(ABC):
|
||||||
@ -54,6 +56,14 @@ class StatLoggerBase(ABC):
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class AggregateStatLoggerBase(StatLoggerBase):
|
||||||
|
"""Abstract base class for loggers that
|
||||||
|
aggregate across multiple DP engines."""
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def __init__(self, vllm_config: VllmConfig, engine_indexes: list[int]): ...
|
||||||
|
|
||||||
|
|
||||||
class LoggingStatLogger(StatLoggerBase):
|
class LoggingStatLogger(StatLoggerBase):
|
||||||
def __init__(self, vllm_config: VllmConfig, engine_index: int = 0):
|
def __init__(self, vllm_config: VllmConfig, engine_index: int = 0):
|
||||||
self.engine_index = engine_index
|
self.engine_index = engine_index
|
||||||
@ -72,6 +82,8 @@ class LoggingStatLogger(StatLoggerBase):
|
|||||||
self.kv_connector_logging = KVConnectorLogging(kv_tranfer_config)
|
self.kv_connector_logging = KVConnectorLogging(kv_tranfer_config)
|
||||||
self.last_prompt_throughput: float = 0.0
|
self.last_prompt_throughput: float = 0.0
|
||||||
self.last_generation_throughput: float = 0.0
|
self.last_generation_throughput: float = 0.0
|
||||||
|
self.engine_is_idle = False
|
||||||
|
self.aggregated = False
|
||||||
|
|
||||||
def _reset(self, now):
|
def _reset(self, now):
|
||||||
self.last_log_time = now
|
self.last_log_time = now
|
||||||
@ -92,6 +104,10 @@ class LoggingStatLogger(StatLoggerBase):
|
|||||||
return 0.0
|
return 0.0
|
||||||
return float(tracked_stats / delta_time)
|
return float(tracked_stats / delta_time)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def log_prefix(self):
|
||||||
|
return "Engine {:03d}: ".format(self.engine_index)
|
||||||
|
|
||||||
def record(
|
def record(
|
||||||
self,
|
self,
|
||||||
scheduler_stats: SchedulerStats | None,
|
scheduler_stats: SchedulerStats | None,
|
||||||
@ -110,34 +126,37 @@ class LoggingStatLogger(StatLoggerBase):
|
|||||||
self.spec_decoding_logging.observe(scheduler_stats.spec_decoding_stats)
|
self.spec_decoding_logging.observe(scheduler_stats.spec_decoding_stats)
|
||||||
if kv_connector_stats := scheduler_stats.kv_connector_stats:
|
if kv_connector_stats := scheduler_stats.kv_connector_stats:
|
||||||
self.kv_connector_logging.observe(kv_connector_stats)
|
self.kv_connector_logging.observe(kv_connector_stats)
|
||||||
self.last_scheduler_stats = scheduler_stats
|
if not self.aggregated:
|
||||||
|
self.last_scheduler_stats = scheduler_stats
|
||||||
if mm_cache_stats:
|
if mm_cache_stats:
|
||||||
self.mm_caching_metrics.observe(mm_cache_stats)
|
self.mm_caching_metrics.observe(mm_cache_stats)
|
||||||
|
|
||||||
def log(self):
|
def _update_stats(self):
|
||||||
now = time.monotonic()
|
now = time.monotonic()
|
||||||
prompt_throughput = self._get_throughput(self.num_prompt_tokens, now)
|
prompt_throughput = self._get_throughput(self.num_prompt_tokens, now)
|
||||||
generation_throughput = self._get_throughput(self.num_generation_tokens, now)
|
generation_throughput = self._get_throughput(self.num_generation_tokens, now)
|
||||||
|
|
||||||
self._reset(now)
|
self._reset(now)
|
||||||
|
self.engine_is_idle = not any(
|
||||||
scheduler_stats = self.last_scheduler_stats
|
|
||||||
|
|
||||||
log_fn = logger.info
|
|
||||||
if not any(
|
|
||||||
(
|
(
|
||||||
prompt_throughput,
|
prompt_throughput,
|
||||||
generation_throughput,
|
generation_throughput,
|
||||||
self.last_prompt_throughput,
|
self.last_prompt_throughput,
|
||||||
self.last_generation_throughput,
|
self.last_generation_throughput,
|
||||||
)
|
)
|
||||||
):
|
)
|
||||||
# Avoid log noise on an idle production system
|
|
||||||
log_fn = logger.debug
|
|
||||||
self.last_generation_throughput = generation_throughput
|
self.last_generation_throughput = generation_throughput
|
||||||
self.last_prompt_throughput = prompt_throughput
|
self.last_prompt_throughput = prompt_throughput
|
||||||
|
|
||||||
|
def aggregate_scheduler_stats(self):
|
||||||
|
# noop for per engine loggers
|
||||||
|
return
|
||||||
|
|
||||||
|
def log(self):
|
||||||
|
self._update_stats()
|
||||||
|
self.aggregate_scheduler_stats()
|
||||||
|
# Avoid log noise on an idle production system
|
||||||
|
log_fn = logger.debug if self.engine_is_idle else logger.info
|
||||||
# Format and print output.
|
# Format and print output.
|
||||||
log_parts = [
|
log_parts = [
|
||||||
"Avg prompt throughput: %.1f tokens/s",
|
"Avg prompt throughput: %.1f tokens/s",
|
||||||
@ -148,11 +167,11 @@ class LoggingStatLogger(StatLoggerBase):
|
|||||||
"Prefix cache hit rate: %.1f%%",
|
"Prefix cache hit rate: %.1f%%",
|
||||||
]
|
]
|
||||||
log_args = [
|
log_args = [
|
||||||
prompt_throughput,
|
self.last_prompt_throughput,
|
||||||
generation_throughput,
|
self.last_generation_throughput,
|
||||||
scheduler_stats.num_running_reqs,
|
self.last_scheduler_stats.num_running_reqs,
|
||||||
scheduler_stats.num_waiting_reqs,
|
self.last_scheduler_stats.num_waiting_reqs,
|
||||||
scheduler_stats.kv_cache_usage * 100,
|
self.last_scheduler_stats.kv_cache_usage * 100,
|
||||||
self.prefix_caching_metrics.hit_rate * 100,
|
self.prefix_caching_metrics.hit_rate * 100,
|
||||||
]
|
]
|
||||||
if not self.mm_caching_metrics.empty:
|
if not self.mm_caching_metrics.empty:
|
||||||
@ -160,8 +179,7 @@ class LoggingStatLogger(StatLoggerBase):
|
|||||||
log_args.append(self.mm_caching_metrics.hit_rate * 100)
|
log_args.append(self.mm_caching_metrics.hit_rate * 100)
|
||||||
|
|
||||||
log_fn(
|
log_fn(
|
||||||
"Engine %03d: " + ", ".join(log_parts),
|
self.log_prefix + ", ".join(log_parts),
|
||||||
self.engine_index,
|
|
||||||
*log_args,
|
*log_args,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -178,7 +196,114 @@ class LoggingStatLogger(StatLoggerBase):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class PrometheusStatLogger(StatLoggerBase):
|
class AggregatedLoggingStatLogger(LoggingStatLogger, AggregateStatLoggerBase):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
vllm_config: VllmConfig,
|
||||||
|
engine_indexes: list[int],
|
||||||
|
):
|
||||||
|
self.engine_indexes = engine_indexes
|
||||||
|
self.last_scheduler_stats_dict: dict[int, SchedulerStats] = {
|
||||||
|
idx: SchedulerStats() for idx in self.engine_indexes
|
||||||
|
}
|
||||||
|
LoggingStatLogger.__init__(self, vllm_config, engine_index=-1)
|
||||||
|
self.aggregated = True
|
||||||
|
|
||||||
|
@property
|
||||||
|
def log_prefix(self):
|
||||||
|
return "{} Engines Aggregated: ".format(len(self.engine_indexes))
|
||||||
|
|
||||||
|
def record(
|
||||||
|
self,
|
||||||
|
scheduler_stats: SchedulerStats | None,
|
||||||
|
iteration_stats: IterationStats | None,
|
||||||
|
mm_cache_stats: MultiModalCacheStats | None = None,
|
||||||
|
engine_idx: int = 0,
|
||||||
|
):
|
||||||
|
if engine_idx not in self.engine_indexes:
|
||||||
|
logger.warning("Unexpected engine_idx: %d", engine_idx)
|
||||||
|
return
|
||||||
|
LoggingStatLogger.record(
|
||||||
|
self,
|
||||||
|
scheduler_stats,
|
||||||
|
iteration_stats,
|
||||||
|
mm_cache_stats=mm_cache_stats,
|
||||||
|
engine_idx=engine_idx,
|
||||||
|
)
|
||||||
|
if scheduler_stats is not None:
|
||||||
|
self.last_scheduler_stats_dict[engine_idx] = scheduler_stats
|
||||||
|
|
||||||
|
def aggregate_scheduler_stats(self):
|
||||||
|
self.last_scheduler_stats = SchedulerStats()
|
||||||
|
for last_scheduler_stats in self.last_scheduler_stats_dict.values():
|
||||||
|
self.last_scheduler_stats.num_waiting_reqs += (
|
||||||
|
last_scheduler_stats.num_waiting_reqs
|
||||||
|
)
|
||||||
|
self.last_scheduler_stats.num_running_reqs += (
|
||||||
|
last_scheduler_stats.num_running_reqs
|
||||||
|
)
|
||||||
|
self.last_scheduler_stats.num_corrupted_reqs += (
|
||||||
|
last_scheduler_stats.num_corrupted_reqs
|
||||||
|
)
|
||||||
|
self.last_scheduler_stats.kv_cache_usage += (
|
||||||
|
last_scheduler_stats.kv_cache_usage
|
||||||
|
)
|
||||||
|
self.last_scheduler_stats.kv_cache_usage /= len(self.last_scheduler_stats_dict)
|
||||||
|
|
||||||
|
def log(self):
|
||||||
|
LoggingStatLogger.log(self)
|
||||||
|
|
||||||
|
def log_engine_initialized(self):
|
||||||
|
if self.vllm_config.cache_config.num_gpu_blocks:
|
||||||
|
logger.info(
|
||||||
|
"%d Engines: vllm cache_config_info with initialization "
|
||||||
|
"after num_gpu_blocks is: %d",
|
||||||
|
len(self.engine_indexes),
|
||||||
|
self.vllm_config.cache_config.num_gpu_blocks,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class PerEngineStatLoggerAdapter(AggregateStatLoggerBase):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
vllm_config: VllmConfig,
|
||||||
|
engine_indexes: list[int],
|
||||||
|
per_engine_stat_logger_factory: PerEngineStatLoggerFactory,
|
||||||
|
) -> None:
|
||||||
|
self.per_engine_stat_loggers = {}
|
||||||
|
self.engine_indexes = engine_indexes
|
||||||
|
for engine_index in engine_indexes:
|
||||||
|
self.per_engine_stat_loggers[engine_index] = per_engine_stat_logger_factory(
|
||||||
|
vllm_config, engine_index
|
||||||
|
)
|
||||||
|
|
||||||
|
def record(
|
||||||
|
self,
|
||||||
|
scheduler_stats: SchedulerStats | None,
|
||||||
|
iteration_stats: IterationStats | None,
|
||||||
|
mm_cache_stats: MultiModalCacheStats | None = None,
|
||||||
|
engine_idx: int = 0,
|
||||||
|
):
|
||||||
|
if engine_idx not in self.per_engine_stat_loggers:
|
||||||
|
logger.warning("Unexpected engine_idx: %d", engine_idx)
|
||||||
|
return
|
||||||
|
self.per_engine_stat_loggers[engine_idx].record(
|
||||||
|
scheduler_stats,
|
||||||
|
iteration_stats,
|
||||||
|
mm_cache_stats=mm_cache_stats,
|
||||||
|
engine_idx=engine_idx,
|
||||||
|
)
|
||||||
|
|
||||||
|
def log(self):
|
||||||
|
for per_engine_stat_logger in self.per_engine_stat_loggers.values():
|
||||||
|
per_engine_stat_logger.log()
|
||||||
|
|
||||||
|
def log_engine_initialized(self):
|
||||||
|
for per_engine_stat_logger in self.per_engine_stat_loggers.values():
|
||||||
|
per_engine_stat_logger.log_engine_initialized()
|
||||||
|
|
||||||
|
|
||||||
|
class PrometheusStatLogger(AggregateStatLoggerBase):
|
||||||
_gauge_cls = Gauge
|
_gauge_cls = Gauge
|
||||||
_counter_cls = Counter
|
_counter_cls = Counter
|
||||||
_histogram_cls = Histogram
|
_histogram_cls = Histogram
|
||||||
@ -189,6 +314,7 @@ class PrometheusStatLogger(StatLoggerBase):
|
|||||||
):
|
):
|
||||||
if engine_indexes is None:
|
if engine_indexes is None:
|
||||||
engine_indexes = [0]
|
engine_indexes = [0]
|
||||||
|
|
||||||
self.engine_indexes = engine_indexes
|
self.engine_indexes = engine_indexes
|
||||||
|
|
||||||
unregister_vllm_metrics()
|
unregister_vllm_metrics()
|
||||||
@ -880,14 +1006,14 @@ class StatLoggerManager:
|
|||||||
engine_idxs: list[int] | None = None,
|
engine_idxs: list[int] | None = None,
|
||||||
custom_stat_loggers: list[StatLoggerFactory] | None = None,
|
custom_stat_loggers: list[StatLoggerFactory] | None = None,
|
||||||
enable_default_loggers: bool = True,
|
enable_default_loggers: bool = True,
|
||||||
|
aggregate_engine_logging: bool = False,
|
||||||
client_count: int = 1,
|
client_count: int = 1,
|
||||||
):
|
):
|
||||||
self.engine_idxs = engine_idxs if engine_idxs else [0]
|
self.engine_indexes = engine_idxs if engine_idxs else [0]
|
||||||
|
self.stat_loggers: list[AggregateStatLoggerBase] = []
|
||||||
factories: list[StatLoggerFactory] = []
|
stat_logger_factories: list[StatLoggerFactory] = []
|
||||||
if custom_stat_loggers is not None:
|
if custom_stat_loggers is not None:
|
||||||
factories.extend(custom_stat_loggers)
|
stat_logger_factories.extend(custom_stat_loggers)
|
||||||
|
|
||||||
if enable_default_loggers and logger.isEnabledFor(logging.INFO):
|
if enable_default_loggers and logger.isEnabledFor(logging.INFO):
|
||||||
if client_count > 1:
|
if client_count > 1:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
@ -895,27 +1021,35 @@ class StatLoggerManager:
|
|||||||
"disabling stats logging to avoid incomplete stats."
|
"disabling stats logging to avoid incomplete stats."
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
factories.append(LoggingStatLogger)
|
default_logger_factory = (
|
||||||
|
AggregatedLoggingStatLogger
|
||||||
# engine_idx: StatLogger
|
if aggregate_engine_logging
|
||||||
self.per_engine_logger_dict: dict[int, list[StatLoggerBase]] = {}
|
else LoggingStatLogger
|
||||||
prometheus_factory = PrometheusStatLogger
|
)
|
||||||
for engine_idx in self.engine_idxs:
|
stat_logger_factories.append(default_logger_factory)
|
||||||
loggers: list[StatLoggerBase] = []
|
custom_prometheus_logger: bool = False
|
||||||
for logger_factory in factories:
|
for stat_logger_factory in stat_logger_factories:
|
||||||
# If we get a custom prometheus logger, use that
|
if isinstance(stat_logger_factory, type) and issubclass(
|
||||||
# instead. This is typically used for the ray case.
|
stat_logger_factory, AggregateStatLoggerBase
|
||||||
if isinstance(logger_factory, type) and issubclass(
|
):
|
||||||
logger_factory, PrometheusStatLogger
|
global_stat_logger = stat_logger_factory(
|
||||||
):
|
vllm_config=vllm_config,
|
||||||
prometheus_factory = logger_factory
|
engine_indexes=self.engine_indexes,
|
||||||
continue
|
)
|
||||||
loggers.append(logger_factory(vllm_config, engine_idx)) # type: ignore
|
if isinstance(global_stat_logger, PrometheusStatLogger):
|
||||||
self.per_engine_logger_dict[engine_idx] = loggers
|
custom_prometheus_logger = True
|
||||||
|
else:
|
||||||
# For Prometheus, need to share the metrics between EngineCores.
|
# per engine logger
|
||||||
# Each EngineCore's metrics are expressed as a unique label.
|
global_stat_logger = PerEngineStatLoggerAdapter(
|
||||||
self.prometheus_logger = prometheus_factory(vllm_config, engine_idxs)
|
vllm_config=vllm_config,
|
||||||
|
engine_indexes=self.engine_indexes,
|
||||||
|
per_engine_stat_logger_factory=stat_logger_factory, # type: ignore[arg-type]
|
||||||
|
)
|
||||||
|
self.stat_loggers.append(global_stat_logger)
|
||||||
|
if not custom_prometheus_logger:
|
||||||
|
self.stat_loggers.append(
|
||||||
|
PrometheusStatLogger(vllm_config, self.engine_indexes)
|
||||||
|
)
|
||||||
|
|
||||||
def record(
|
def record(
|
||||||
self,
|
self,
|
||||||
@ -926,9 +1060,7 @@ class StatLoggerManager:
|
|||||||
):
|
):
|
||||||
if engine_idx is None:
|
if engine_idx is None:
|
||||||
engine_idx = 0
|
engine_idx = 0
|
||||||
|
for logger in self.stat_loggers:
|
||||||
per_engine_loggers = self.per_engine_logger_dict[engine_idx]
|
|
||||||
for logger in per_engine_loggers:
|
|
||||||
logger.record(
|
logger.record(
|
||||||
scheduler_stats,
|
scheduler_stats,
|
||||||
iteration_stats,
|
iteration_stats,
|
||||||
@ -936,21 +1068,10 @@ class StatLoggerManager:
|
|||||||
engine_idx=engine_idx,
|
engine_idx=engine_idx,
|
||||||
)
|
)
|
||||||
|
|
||||||
self.prometheus_logger.record(
|
|
||||||
scheduler_stats,
|
|
||||||
iteration_stats,
|
|
||||||
mm_cache_stats=mm_cache_stats,
|
|
||||||
engine_idx=engine_idx,
|
|
||||||
)
|
|
||||||
|
|
||||||
def log(self):
|
def log(self):
|
||||||
for per_engine_loggers in self.per_engine_logger_dict.values():
|
for logger in self.stat_loggers:
|
||||||
for logger in per_engine_loggers:
|
logger.log()
|
||||||
logger.log()
|
|
||||||
|
|
||||||
def log_engine_initialized(self):
|
def log_engine_initialized(self):
|
||||||
self.prometheus_logger.log_engine_initialized()
|
for agg_logger in self.stat_loggers:
|
||||||
|
agg_logger.log_engine_initialized()
|
||||||
for per_engine_loggers in self.per_engine_logger_dict.values():
|
|
||||||
for logger in per_engine_loggers:
|
|
||||||
logger.log_engine_initialized()
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user