[Misc][DP] support customized global logger for dp
Signed-off-by: Lu Fang <fanglu@fb.com>
fix the test
Signed-off-by: Lu Fang <fanglu@fb.com>
address comments
Signed-off-by: Lu Fang <fanglu@fb.com>
parent 064cac7bb7
commit a46e279909
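In short: AsyncLLM (and the AsyncLLMEngine wrapper used in the example) accepts a second factory argument, stat_logger_global, next to the existing per-engine stat_loggers list. Per-engine factories produce one stat logger per data-parallel engine rank; the global factory produces a single logger that aggregates stats across all ranks managed by the client. A condensed sketch of the resulting call site (values are illustrative, taken from the example below):

    engine_client = AsyncLLMEngine.from_engine_args(
        engine_args,
        stat_loggers=[per_engine_logger_factory],    # one logger per DP engine rank
        stat_logger_global=global_logger_factory,    # one aggregated logger across ranks
    )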
@@ -1,12 +1,15 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import asyncio
import threading
from typing import Optional

from vllm.config import VllmConfig
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.outputs import RequestOutput
from vllm.sampling_params import SamplingParams
from vllm.v1.metrics.loggers import GlobalStatLogger, LoggingStatLogger

"""
To run this example, run the following commands simultaneously with
@@ -22,37 +25,72 @@ send a request to the instance with DP rank 1.
"""


def _do_background_logging(engine, interval, stop_event):
    try:
        while not stop_event.is_set():
            asyncio.run(engine.do_log_stats())
            stop_event.wait(interval)
    except Exception as e:
        print(f"vLLM background logging shutdown: {e}")
        pass


async def main():
    engine_args = AsyncEngineArgs(
        model="ibm-research/PowerMoE-3b",
        data_parallel_size=2,
        tensor_parallel_size=1,
        dtype="auto",
        max_model_len=2048,
        data_parallel_address="127.0.0.1",
        data_parallel_rpc_port=62300,
        data_parallel_size_local=1,
        enforce_eager=True,
        enable_log_requests=True,
        disable_custom_all_reduce=True,
    )

-    engine_client = AsyncLLMEngine.from_engine_args(engine_args)
+    def per_engine_logger_factory(config: VllmConfig, rank: int) -> LoggingStatLogger:
+        return LoggingStatLogger(config, rank)
+
+    def global_logger_factory(
+        config: VllmConfig, engine_idxs: Optional[list[int]]
+    ) -> GlobalStatLogger:
+        return GlobalStatLogger(config, engine_indexes=engine_idxs)
+
+    engine_client = AsyncLLMEngine.from_engine_args(
+        engine_args,
+        stat_loggers=[per_engine_logger_factory],
+        stat_logger_global=global_logger_factory,
+    )
    stop_logging_event = threading.Event()
    logging_thread = threading.Thread(
        target=_do_background_logging,
        args=(engine_client, 5, stop_logging_event),
        daemon=True,
    )
    logging_thread.start()
    sampling_params = SamplingParams(
        temperature=0.7,
        top_p=0.9,
        max_tokens=100,
    )
+    num_prompts = 10
+    for i in range(num_prompts):
+        prompt = "Who won the 2004 World Series?"
+        final_output: Optional[RequestOutput] = None
+        async for output in engine_client.generate(
+            prompt=prompt,
+            sampling_params=sampling_params,
+            request_id=f"abcdef-{i}",
+            data_parallel_rank=1,
+        ):
+            final_output = output
+        if final_output:
+            print(final_output.outputs[0].text)

-    prompt = "Who won the 2004 World Series?"
-    final_output: Optional[RequestOutput] = None
-    async for output in engine_client.generate(
-        prompt=prompt,
-        sampling_params=sampling_params,
-        request_id="abcdef",
-        data_parallel_rank=1,
-    ):
-        final_output = output
-    if final_output:
-        print(final_output.outputs[0].text)
    stop_logging_event.set()
    logging_thread.join()


if __name__ == "__main__":
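A note on the example's background logging: _do_background_logging runs on a daemon thread and calls asyncio.run() on every tick, creating and tearing down a fresh event loop per flush. An asyncio-native variant that stays on one loop is sketched below; it is an alternative for illustration only (assuming just that engine.do_log_stats() is awaitable, as the example uses it), not part of the commit:

    async def periodic_logging(engine, interval: float, stop: asyncio.Event) -> None:
        # Flush stats every `interval` seconds until `stop` is set.
        while not stop.is_set():
            await engine.do_log_stats()
            try:
                await asyncio.wait_for(stop.wait(), timeout=interval)
            except asyncio.TimeoutError:
                pass  # timeout simply means: time to log again

    # Inside main():
    #     stop = asyncio.Event()
    #     task = asyncio.create_task(periodic_logging(engine_client, 5, stop))
    #     ...
    #     stop.set()
    #     await task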
@@ -18,7 +18,7 @@ from vllm.platforms import current_platform
from vllm.sampling_params import RequestOutputKind
from vllm.utils import set_default_torch_num_threads
from vllm.v1.engine.async_llm import AsyncLLM
-from vllm.v1.metrics.loggers import LoggingStatLogger
+from vllm.v1.metrics.loggers import GlobalStatLogger, LoggingStatLogger

if not current_platform.is_cuda():
    pytest.skip(reason="V1 currently only supported on CUDA.",
@@ -389,6 +389,15 @@ class MockLoggingStatLogger(LoggingStatLogger):
        self.log = MagicMock()


+class MockGlobalStatLogger(GlobalStatLogger):
+
+    def __init__(self,
+                 vllm_config: VllmConfig,
+                 engine_indexes: Optional[list[int]] = None):
+        super().__init__(vllm_config, engine_indexes)
+        self.log = MagicMock()
+
+
@pytest.mark.asyncio
async def test_customize_loggers(monkeypatch):
    """Test that we can customize the loggers.
@@ -415,6 +424,36 @@ async def test_customize_loggers(monkeypatch):
    stat_loggers[0][0].log.assert_called_once()


+@pytest.mark.asyncio
+async def test_customize_global_loggers(monkeypatch):
+    """Test that we can customize the global logger.
+
+    If a customized global logger is provided at init, it should
+    be used in addition to the default per-engine loggers.
+    """
+
+    with monkeypatch.context() as m, ExitStack() as after:
+        m.setenv("VLLM_USE_V1", "1")
+
+        with set_default_torch_num_threads(1):
+            engine = AsyncLLM.from_engine_args(
+                TEXT_ENGINE_ARGS,
+                stat_loggers=[MockLoggingStatLogger],
+                stat_logger_global=MockGlobalStatLogger,
+            )
+        after.callback(engine.shutdown)
+
+        await engine.do_log_stats()
+
+        stat_loggers = engine.logger_manager.per_engine_logger_dict
+        assert len(stat_loggers) == 1
+        assert len(
+            stat_loggers[0]) == 2  # LoggingStatLogger + MockLoggingStatLogger
+        global_logger = engine.logger_manager.global_logger
+        assert global_logger is not None
+        global_logger.log.assert_called_once()
+        stat_loggers[0][0].log.assert_called_once()
+
+
@pytest.mark.asyncio(scope="module")
async def test_dp_rank_argument(monkeypatch: pytest.MonkeyPatch):
    with monkeypatch.context() as m, ExitStack() as after:
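Both mock classes above replace log with a MagicMock on the instance, so the tests can assert exactly one flush without parsing log output. A standalone illustration of that spy pattern (class names here are hypothetical, not from the diff):

    from unittest.mock import MagicMock


    class RealLogger:
        def log(self):
            print("real logging work")


    class SpyLogger(RealLogger):
        def __init__(self):
            self.log = MagicMock()  # instance attribute shadows the inherited method


    spy = SpyLogger()
    spy.log()
    spy.log.assert_called_once()  # passes; raises AssertionError on zero or repeated calls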
@@ -42,7 +42,8 @@ from vllm.v1.engine.output_processor import (OutputProcessor,
from vllm.v1.engine.parallel_sampling import ParentRequest
from vllm.v1.engine.processor import Processor
from vllm.v1.executor.abstract import Executor
-from vllm.v1.metrics.loggers import StatLoggerFactory, StatLoggerManager
+from vllm.v1.metrics.loggers import (GlobalStatLoggerFactory,
+                                     StatLoggerFactory, StatLoggerManager)
from vllm.v1.metrics.prometheus import shutdown_prometheus
from vllm.v1.metrics.stats import IterationStats

@@ -62,6 +63,7 @@ class AsyncLLM(EngineClient):
        log_requests: bool = True,
        start_engine_loop: bool = True,
        stat_loggers: Optional[list[StatLoggerFactory]] = None,
+        stat_logger_global: Optional[GlobalStatLoggerFactory] = None,
        client_addresses: Optional[dict[str, str]] = None,
        client_count: int = 1,
        client_index: int = 0,
@@ -101,8 +103,11 @@ class AsyncLLM(EngineClient):
        self.observability_config = vllm_config.observability_config
        self.log_requests = log_requests

-        self.log_stats = log_stats or (stat_loggers is not None)
-        if not log_stats and stat_loggers is not None:
+        self.log_stats = log_stats or (stat_loggers
+                                       is not None) or (stat_logger_global
+                                                        is not None)
+        if (not log_stats and stat_loggers is not None
+                or stat_logger_global is not None):
            logger.info(
                "AsyncLLM created with log_stats=False and non-empty custom "
                "logger list; enabling logging without default stat loggers")
@@ -147,6 +152,7 @@ class AsyncLLM(EngineClient):
            vllm_config=vllm_config,
            engine_idxs=self.engine_core.engine_ranks_managed,
            custom_stat_loggers=stat_loggers,
+            custom_stat_logger_global=stat_logger_global,
            enable_default_loggers=log_stats,
            client_count=client_count,
        )
@@ -189,6 +195,7 @@ class AsyncLLM(EngineClient):
        start_engine_loop: bool = True,
        usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
        stat_loggers: Optional[list[StatLoggerFactory]] = None,
+        stat_logger_global: Optional[GlobalStatLoggerFactory] = None,
        enable_log_requests: bool = False,
        disable_log_stats: bool = False,
        client_addresses: Optional[dict[str, str]] = None,
@@ -209,6 +216,7 @@ class AsyncLLM(EngineClient):
            executor_class=Executor.get_class(vllm_config),
            start_engine_loop=start_engine_loop,
            stat_loggers=stat_loggers,
+            stat_logger_global=stat_logger_global,
            log_requests=enable_log_requests,
            log_stats=not disable_log_stats,
            usage_context=usage_context,
@@ -224,6 +232,7 @@ class AsyncLLM(EngineClient):
        start_engine_loop: bool = True,
        usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
        stat_loggers: Optional[list[StatLoggerFactory]] = None,
+        stat_logger_global: Optional[GlobalStatLoggerFactory] = None,
    ) -> "AsyncLLM":
        """Create an AsyncLLM from the EngineArgs."""

@@ -240,6 +249,7 @@ class AsyncLLM(EngineClient):
            start_engine_loop=start_engine_loop,
            usage_context=usage_context,
            stat_loggers=stat_loggers,
+            stat_logger_global=stat_logger_global,
        )

    def __del__(self):
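Two details in the @@ -101,8 +103,11 @@ hunk are easy to misread. First, stats collection is now enabled whenever either factory is supplied, even with log_stats=False. Second, since `and` binds tighter than `or`, Python groups the info-message condition as (not log_stats and stat_loggers is not None) or (stat_logger_global is not None), so the message also fires when a global logger is passed alongside log_stats=True. A flattened sketch of the enablement rule, as a reading aid rather than committed code:

    def effective_log_stats(log_stats: bool, stat_loggers, stat_logger_global) -> bool:
        # Stats are collected if explicitly enabled, or if the caller
        # supplied a per-engine factory list or a global logger factory.
        return (log_stats
                or stat_loggers is not None
                or stat_logger_global is not None)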
@@ -19,6 +19,8 @@ from vllm.v1.spec_decode.metrics import SpecDecodingLogging, SpecDecodingProm
logger = init_logger(__name__)

StatLoggerFactory = Callable[[VllmConfig, int], "StatLoggerBase"]
+GlobalStatLoggerFactory = Callable[[VllmConfig, list[int]],
+                                   "GlobalStatLoggerBase"]


class StatLoggerBase(ABC):
@@ -48,6 +50,15 @@ class StatLoggerBase(ABC):
        pass


+class GlobalStatLoggerBase(StatLoggerBase):
+    """Interface for logging metrics for multiple engines."""
+
+    @abstractmethod
+    def __init__(self, vllm_config: VllmConfig,
+                 engine_indexes: Optional[list[int]]):
+        ...
+
+
class LoggingStatLogger(StatLoggerBase):

    def __init__(self, vllm_config: VllmConfig, engine_index: int = 0):
@@ -145,6 +156,63 @@ class LoggingStatLogger(StatLoggerBase):
            self.vllm_config.cache_config.num_gpu_blocks)


+class GlobalStatLogger(LoggingStatLogger, GlobalStatLoggerBase):
+
+    def __init__(self,
+                 vllm_config: VllmConfig,
+                 engine_indexes: Optional[list[int]] = None):
+        super().__init__(vllm_config, -1)
+        if engine_indexes is None:
+            engine_indexes = [0]
+        self.engine_index = -1
+        self.engine_indexes = engine_indexes
+        self.vllm_config = vllm_config
+
+    def log(self):
+        now = time.monotonic()
+        prompt_throughput = self._get_throughput(self.num_prompt_tokens, now)
+        generation_throughput = self._get_throughput(
+            self.num_generation_tokens, now)
+
+        self._reset(now)
+
+        scheduler_stats = self.last_scheduler_stats
+
+        log_fn = logger.info
+        if not any(
+            (prompt_throughput, generation_throughput,
+             self.last_prompt_throughput, self.last_generation_throughput)):
+            # Avoid log noise on an idle production system
+            log_fn = logger.debug
+        self.last_generation_throughput = generation_throughput
+        self.last_prompt_throughput = prompt_throughput
+
+        # Format and print output.
+        log_fn(
+            "%s Engines Aggregated: "
+            "Avg prompt throughput: %.1f tokens/s, "
+            "Avg generation throughput: %.1f tokens/s, "
+            "Running: %d reqs, Waiting: %d reqs, "
+            "GPU KV cache usage: %.1f%%, "
+            "Prefix cache hit rate: %.1f%%",
+            len(self.engine_indexes),
+            prompt_throughput,
+            generation_throughput,
+            scheduler_stats.num_running_reqs,
+            scheduler_stats.num_waiting_reqs,
+            scheduler_stats.kv_cache_usage * 100,
+            self.prefix_caching_metrics.hit_rate * 100,
+        )
+        self.spec_decoding_logging.log(log_fn=log_fn)
+
+    def log_engine_initialized(self):
+        if self.vllm_config.cache_config.num_gpu_blocks:
+            logger.info(
+                "%d Engines: vllm cache_config_info with initialization "
+                "after num_gpu_blocks is: %d", len(self.engine_indexes),
+                self.vllm_config.cache_config.num_gpu_blocks)
+
+
class PrometheusStatLogger(StatLoggerBase):
    _gauge_cls = prometheus_client.Gauge
    _counter_cls = prometheus_client.Counter
@@ -655,6 +723,7 @@ class StatLoggerManager:
        vllm_config: VllmConfig,
        engine_idxs: Optional[list[int]] = None,
        custom_stat_loggers: Optional[list[StatLoggerFactory]] = None,
+        custom_stat_logger_global: Optional[GlobalStatLoggerFactory] = None,
        enable_default_loggers: bool = True,
        client_count: int = 1,
    ):
@@ -688,9 +757,14 @@ class StatLoggerManager:
                    engine_idx))  # type: ignore
            self.per_engine_logger_dict[engine_idx] = loggers

-        # For Prometheus, need to share the metrics between EngineCores.
+        # For Prometheus or custom global logger,
+        # need to share the metrics between EngineCores.
        # Each EngineCore's metrics are expressed as a unique label.
        self.prometheus_logger = prometheus_factory(vllm_config, engine_idxs)
+        self.global_logger: Optional[StatLoggerBase] = None
+        if custom_stat_logger_global is not None:
+            self.global_logger = custom_stat_logger_global(
+                vllm_config, self.engine_idxs)

    def record(
        self,
@@ -707,15 +781,21 @@ class StatLoggerManager:

        self.prometheus_logger.record(scheduler_stats, iteration_stats,
                                      engine_idx)
+        if self.global_logger is not None:
+            self.global_logger.record(scheduler_stats, iteration_stats,
+                                      engine_idx)

    def log(self):
        for per_engine_loggers in self.per_engine_logger_dict.values():
            for logger in per_engine_loggers:
                logger.log()
+        if self.global_logger is not None:
+            self.global_logger.log()

    def log_engine_initialized(self):
        self.prometheus_logger.log_engine_initialized()
-
+        if self.global_logger is not None:
+            self.global_logger.log_engine_initialized()
        for per_engine_loggers in self.per_engine_logger_dict.values():
            for logger in per_engine_loggers:
                logger.log_engine_initialized()
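To use the new hook, any callable matching GlobalStatLoggerFactory — (VllmConfig, list[int]) -> GlobalStatLoggerBase — can be passed as stat_logger_global. A minimal custom aggregator is sketched below; the class and its counting logic are hypothetical, and only the constructor signature and the record/log/log_engine_initialized hooks come from the diff above:

    from typing import Optional

    from vllm.config import VllmConfig
    from vllm.v1.metrics.loggers import GlobalStatLoggerBase


    class RecordCountLogger(GlobalStatLoggerBase):
        # Hypothetical aggregator: counts record() calls per engine index.

        def __init__(self, vllm_config: VllmConfig,
                     engine_indexes: Optional[list[int]] = None):
            self.engine_indexes = engine_indexes if engine_indexes else [0]
            self.counts = {idx: 0 for idx in self.engine_indexes}

        def record(self, scheduler_stats, iteration_stats, engine_idx: int = 0):
            self.counts[engine_idx] = self.counts.get(engine_idx, 0) + 1

        def log(self):
            print(f"record() calls per engine: {self.counts}")

        def log_engine_initialized(self):
            print(f"{len(self.engine_indexes)} engines initialized")


    # Passed as a factory, mirroring the test above:
    # engine = AsyncLLM.from_engine_args(args, stat_logger_global=RecordCountLogger)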