mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-10 04:05:01 +08:00
[V1][Metrics][Plugin] Add plugin support for custom StatLoggerBase implementations (#22456)
Signed-off-by: tovam <tovam@pliops.com>
This commit is contained in:
parent
c2bba69065
commit
83e760c57d
@ -1020,6 +1020,11 @@ steps:
|
|||||||
- pytest -v -s plugins_tests/test_io_processor_plugins.py
|
- pytest -v -s plugins_tests/test_io_processor_plugins.py
|
||||||
- pip uninstall prithvi_io_processor_plugin -y
|
- pip uninstall prithvi_io_processor_plugin -y
|
||||||
# end io_processor plugins test
|
# end io_processor plugins test
|
||||||
|
# begin stat_logger plugins test
|
||||||
|
- pip install -e ./plugins/vllm_add_dummy_stat_logger
|
||||||
|
- pytest -v -s plugins_tests/test_stats_logger_plugins.py
|
||||||
|
- pip uninstall dummy_stat_logger -y
|
||||||
|
# end stat_logger plugins test
|
||||||
# other tests continue here:
|
# other tests continue here:
|
||||||
- pytest -v -s plugins_tests/test_scheduler_plugins.py
|
- pytest -v -s plugins_tests/test_scheduler_plugins.py
|
||||||
- pip install -e ./plugins/vllm_add_dummy_model
|
- pip install -e ./plugins/vllm_add_dummy_model
|
||||||
|
|||||||
@ -41,7 +41,7 @@ Every plugin has three parts:
|
|||||||
|
|
||||||
1. **Plugin group**: The name of the entry point group. vLLM uses the entry point group `vllm.general_plugins` to register general plugins. This is the key of `entry_points` in the `setup.py` file. Always use `vllm.general_plugins` for vLLM's general plugins.
|
1. **Plugin group**: The name of the entry point group. vLLM uses the entry point group `vllm.general_plugins` to register general plugins. This is the key of `entry_points` in the `setup.py` file. Always use `vllm.general_plugins` for vLLM's general plugins.
|
||||||
2. **Plugin name**: The name of the plugin. This is the value in the dictionary of the `entry_points` dictionary. In the example above, the plugin name is `register_dummy_model`. Plugins can be filtered by their names using the `VLLM_PLUGINS` environment variable. To load only a specific plugin, set `VLLM_PLUGINS` to the plugin name.
|
2. **Plugin name**: The name of the plugin. This is the value in the dictionary of the `entry_points` dictionary. In the example above, the plugin name is `register_dummy_model`. Plugins can be filtered by their names using the `VLLM_PLUGINS` environment variable. To load only a specific plugin, set `VLLM_PLUGINS` to the plugin name.
|
||||||
3. **Plugin value**: The fully qualified name of the function to register in the plugin system. In the example above, the plugin value is `vllm_add_dummy_model:register`, which refers to a function named `register` in the `vllm_add_dummy_model` module.
|
3. **Plugin value**: The fully qualified name of the function or module to register in the plugin system. In the example above, the plugin value is `vllm_add_dummy_model:register`, which refers to a function named `register` in the `vllm_add_dummy_model` module.
|
||||||
|
|
||||||
## Types of supported plugins
|
## Types of supported plugins
|
||||||
|
|
||||||
@ -51,6 +51,8 @@ Every plugin has three parts:
|
|||||||
|
|
||||||
- **IO Processor plugins** (with group name `vllm.io_processor_plugins`): The primary use case for these plugins is to register custom pre/post processing of the model prompt and model output for pooling models. The plugin function returns the IOProcessor's class fully qualified name.
|
- **IO Processor plugins** (with group name `vllm.io_processor_plugins`): The primary use case for these plugins is to register custom pre/post processing of the model prompt and model output for pooling models. The plugin function returns the IOProcessor's class fully qualified name.
|
||||||
|
|
||||||
|
- **Stat logger plugins** (with group name `vllm.stat_logger_plugins`): The primary use case for these plugins is to register custom, out-of-the-tree loggers into vLLM. The entry point should be a class that subclasses StatLoggerBase.
|
||||||
|
|
||||||
## Guidelines for Writing Plugins
|
## Guidelines for Writing Plugins
|
||||||
|
|
||||||
- **Being re-entrant**: The function specified in the entry point should be re-entrant, meaning it can be called multiple times without causing issues. This is necessary because the function might be called multiple times in some processes.
|
- **Being re-entrant**: The function specified in the entry point should be re-entrant, meaning it can be called multiple times without causing issues. This is necessary because the function might be called multiple times in some processes.
|
||||||
|
|||||||
@ -0,0 +1,29 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
|
||||||
|
from vllm.v1.metrics.loggers import StatLoggerBase
|
||||||
|
|
||||||
|
|
||||||
|
class DummyStatLogger(StatLoggerBase):
|
||||||
|
"""
|
||||||
|
A dummy stat logger for testing purposes.
|
||||||
|
Implements the minimal interface expected by StatLoggerManager.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, vllm_config, engine_idx=0):
|
||||||
|
self.vllm_config = vllm_config
|
||||||
|
self.engine_idx = engine_idx
|
||||||
|
self.recorded = []
|
||||||
|
self.logged = False
|
||||||
|
self.engine_initialized = False
|
||||||
|
|
||||||
|
def record(self, scheduler_stats, iteration_stats, mm_cache_stats, engine_idx):
|
||||||
|
self.recorded.append(
|
||||||
|
(scheduler_stats, iteration_stats, mm_cache_stats, engine_idx)
|
||||||
|
)
|
||||||
|
|
||||||
|
def log(self):
|
||||||
|
self.logged = True
|
||||||
|
|
||||||
|
def log_engine_initialized(self):
|
||||||
|
self.engine_initialized = True
|
||||||
15
tests/plugins/vllm_add_dummy_stat_logger/setup.py
Normal file
15
tests/plugins/vllm_add_dummy_stat_logger/setup.py
Normal file
@ -0,0 +1,15 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
|
||||||
|
from setuptools import setup
|
||||||
|
|
||||||
|
setup(
|
||||||
|
name="dummy_stat_logger",
|
||||||
|
version="0.1",
|
||||||
|
packages=["dummy_stat_logger"],
|
||||||
|
entry_points={
|
||||||
|
"vllm.stat_logger_plugins": [
|
||||||
|
"dummy_stat_logger = dummy_stat_logger.dummy_stat_logger:DummyStatLogger" # noqa
|
||||||
|
]
|
||||||
|
},
|
||||||
|
)
|
||||||
76
tests/plugins_tests/test_stats_logger_plugins.py
Normal file
76
tests/plugins_tests/test_stats_logger_plugins.py
Normal file
@ -0,0 +1,76 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from dummy_stat_logger.dummy_stat_logger import DummyStatLogger
|
||||||
|
|
||||||
|
from vllm.config import VllmConfig
|
||||||
|
from vllm.engine.arg_utils import AsyncEngineArgs
|
||||||
|
from vllm.v1.engine.async_llm import AsyncLLM
|
||||||
|
from vllm.v1.metrics.loggers import load_stat_logger_plugin_factories
|
||||||
|
|
||||||
|
|
||||||
|
def test_stat_logger_plugin_is_discovered(monkeypatch: pytest.MonkeyPatch):
|
||||||
|
with monkeypatch.context() as m:
|
||||||
|
m.setenv("VLLM_PLUGINS", "dummy_stat_logger")
|
||||||
|
|
||||||
|
factories = load_stat_logger_plugin_factories()
|
||||||
|
assert len(factories) == 1, f"Expected 1 factory, got {len(factories)}"
|
||||||
|
assert factories[0] is DummyStatLogger, (
|
||||||
|
f"Expected DummyStatLogger class, got {factories[0]}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# instantiate and confirm the right type
|
||||||
|
vllm_config = VllmConfig()
|
||||||
|
instance = factories[0](vllm_config)
|
||||||
|
assert isinstance(instance, DummyStatLogger)
|
||||||
|
|
||||||
|
|
||||||
|
def test_no_plugins_loaded_if_env_empty(monkeypatch: pytest.MonkeyPatch):
|
||||||
|
with monkeypatch.context() as m:
|
||||||
|
m.setenv("VLLM_PLUGINS", "")
|
||||||
|
|
||||||
|
factories = load_stat_logger_plugin_factories()
|
||||||
|
assert factories == []
|
||||||
|
|
||||||
|
|
||||||
|
def test_invalid_stat_logger_plugin_raises(monkeypatch: pytest.MonkeyPatch):
|
||||||
|
def fake_plugin_loader(group: str):
|
||||||
|
assert group == "vllm.stat_logger_plugins"
|
||||||
|
return {"bad": object()}
|
||||||
|
|
||||||
|
with monkeypatch.context() as m:
|
||||||
|
m.setattr(
|
||||||
|
"vllm.v1.metrics.loggers.load_plugins_by_group",
|
||||||
|
fake_plugin_loader,
|
||||||
|
)
|
||||||
|
with pytest.raises(
|
||||||
|
TypeError,
|
||||||
|
match="Stat logger plugin 'bad' must be a subclass of StatLoggerBase",
|
||||||
|
):
|
||||||
|
load_stat_logger_plugin_factories()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_stat_logger_plugin_integration_with_engine(
|
||||||
|
monkeypatch: pytest.MonkeyPatch,
|
||||||
|
):
|
||||||
|
with monkeypatch.context() as m:
|
||||||
|
m.setenv("VLLM_PLUGINS", "dummy_stat_logger")
|
||||||
|
|
||||||
|
engine_args = AsyncEngineArgs(
|
||||||
|
model="facebook/opt-125m",
|
||||||
|
enforce_eager=True, # reduce test time
|
||||||
|
disable_log_stats=True, # disable default loggers
|
||||||
|
)
|
||||||
|
|
||||||
|
engine = AsyncLLM.from_engine_args(engine_args=engine_args)
|
||||||
|
|
||||||
|
assert len(engine.logger_manager.stat_loggers) == 2
|
||||||
|
assert len(engine.logger_manager.stat_loggers[0].per_engine_stat_loggers) == 1
|
||||||
|
assert isinstance(
|
||||||
|
engine.logger_manager.stat_loggers[0].per_engine_stat_loggers[0],
|
||||||
|
DummyStatLogger,
|
||||||
|
)
|
||||||
|
|
||||||
|
engine.shutdown()
|
||||||
@ -4,33 +4,13 @@ import copy
|
|||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
|
from tests.plugins.vllm_add_dummy_stat_logger.dummy_stat_logger.dummy_stat_logger import ( # noqa E501
|
||||||
|
DummyStatLogger,
|
||||||
|
)
|
||||||
from vllm.v1.engine.async_llm import AsyncEngineArgs, AsyncLLM
|
from vllm.v1.engine.async_llm import AsyncEngineArgs, AsyncLLM
|
||||||
from vllm.v1.metrics.ray_wrappers import RayPrometheusStatLogger
|
from vllm.v1.metrics.ray_wrappers import RayPrometheusStatLogger
|
||||||
|
|
||||||
|
|
||||||
class DummyStatLogger:
|
|
||||||
"""
|
|
||||||
A dummy stat logger for testing purposes.
|
|
||||||
Implements the minimal interface expected by StatLoggerManager.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, vllm_config, engine_idx):
|
|
||||||
self.vllm_config = vllm_config
|
|
||||||
self.engine_idx = engine_idx
|
|
||||||
self.recorded = []
|
|
||||||
self.logged = False
|
|
||||||
self.engine_initialized = False
|
|
||||||
|
|
||||||
def record(self, scheduler_stats, iteration_stats, engine_idx):
|
|
||||||
self.recorded.append((scheduler_stats, iteration_stats, engine_idx))
|
|
||||||
|
|
||||||
def log(self):
|
|
||||||
self.logged = True
|
|
||||||
|
|
||||||
def log_engine_initialized(self):
|
|
||||||
self.engine_initialized = True
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def log_stats_enabled_engine_args():
|
def log_stats_enabled_engine_args():
|
||||||
"""
|
"""
|
||||||
|
|||||||
@ -40,7 +40,11 @@ from vllm.v1.engine.output_processor import OutputProcessor, RequestOutputCollec
|
|||||||
from vllm.v1.engine.parallel_sampling import ParentRequest
|
from vllm.v1.engine.parallel_sampling import ParentRequest
|
||||||
from vllm.v1.engine.processor import Processor
|
from vllm.v1.engine.processor import Processor
|
||||||
from vllm.v1.executor.abstract import Executor
|
from vllm.v1.executor.abstract import Executor
|
||||||
from vllm.v1.metrics.loggers import StatLoggerFactory, StatLoggerManager
|
from vllm.v1.metrics.loggers import (
|
||||||
|
StatLoggerFactory,
|
||||||
|
StatLoggerManager,
|
||||||
|
load_stat_logger_plugin_factories,
|
||||||
|
)
|
||||||
from vllm.v1.metrics.prometheus import shutdown_prometheus
|
from vllm.v1.metrics.prometheus import shutdown_prometheus
|
||||||
from vllm.v1.metrics.stats import IterationStats
|
from vllm.v1.metrics.stats import IterationStats
|
||||||
|
|
||||||
@ -100,11 +104,16 @@ class AsyncLLM(EngineClient):
|
|||||||
self.observability_config = vllm_config.observability_config
|
self.observability_config = vllm_config.observability_config
|
||||||
self.log_requests = log_requests
|
self.log_requests = log_requests
|
||||||
|
|
||||||
self.log_stats = log_stats or (stat_loggers is not None)
|
custom_stat_loggers = list(stat_loggers or [])
|
||||||
if not log_stats and stat_loggers is not None:
|
custom_stat_loggers.extend(load_stat_logger_plugin_factories())
|
||||||
|
|
||||||
|
has_custom_loggers = bool(custom_stat_loggers)
|
||||||
|
self.log_stats = log_stats or has_custom_loggers
|
||||||
|
if not log_stats and has_custom_loggers:
|
||||||
logger.info(
|
logger.info(
|
||||||
"AsyncLLM created with log_stats=False and non-empty custom "
|
"AsyncLLM created with log_stats=False, "
|
||||||
"logger list; enabling logging without default stat loggers"
|
"but custom stat loggers were found; "
|
||||||
|
"enabling logging without default stat loggers."
|
||||||
)
|
)
|
||||||
|
|
||||||
if self.model_config.skip_tokenizer_init:
|
if self.model_config.skip_tokenizer_init:
|
||||||
@ -144,7 +153,7 @@ class AsyncLLM(EngineClient):
|
|||||||
self.logger_manager = StatLoggerManager(
|
self.logger_manager = StatLoggerManager(
|
||||||
vllm_config=vllm_config,
|
vllm_config=vllm_config,
|
||||||
engine_idxs=self.engine_core.engine_ranks_managed,
|
engine_idxs=self.engine_core.engine_ranks_managed,
|
||||||
custom_stat_loggers=stat_loggers,
|
custom_stat_loggers=custom_stat_loggers,
|
||||||
enable_default_loggers=log_stats,
|
enable_default_loggers=log_stats,
|
||||||
client_count=client_count,
|
client_count=client_count,
|
||||||
aggregate_engine_logging=aggregate_engine_logging,
|
aggregate_engine_logging=aggregate_engine_logging,
|
||||||
|
|||||||
@ -12,6 +12,7 @@ from prometheus_client import Counter, Gauge, Histogram
|
|||||||
from vllm.config import SupportsMetricsInfo, VllmConfig
|
from vllm.config import SupportsMetricsInfo, VllmConfig
|
||||||
from vllm.distributed.kv_transfer.kv_connector.v1.metrics import KVConnectorLogging
|
from vllm.distributed.kv_transfer.kv_connector.v1.metrics import KVConnectorLogging
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
|
from vllm.plugins import load_plugins_by_group
|
||||||
from vllm.v1.engine import FinishReason
|
from vllm.v1.engine import FinishReason
|
||||||
from vllm.v1.metrics.prometheus import unregister_vllm_metrics
|
from vllm.v1.metrics.prometheus import unregister_vllm_metrics
|
||||||
from vllm.v1.metrics.stats import (
|
from vllm.v1.metrics.stats import (
|
||||||
@ -56,6 +57,23 @@ class StatLoggerBase(ABC):
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
def load_stat_logger_plugin_factories() -> list[StatLoggerFactory]:
|
||||||
|
factories: list[StatLoggerFactory] = []
|
||||||
|
|
||||||
|
for name, plugin_class in load_plugins_by_group("vllm.stat_logger_plugins").items():
|
||||||
|
if not isinstance(plugin_class, type) or not issubclass(
|
||||||
|
plugin_class, StatLoggerBase
|
||||||
|
):
|
||||||
|
raise TypeError(
|
||||||
|
f"Stat logger plugin {name!r} must be a subclass of "
|
||||||
|
f"StatLoggerBase (got {plugin_class!r})."
|
||||||
|
)
|
||||||
|
|
||||||
|
factories.append(plugin_class)
|
||||||
|
|
||||||
|
return factories
|
||||||
|
|
||||||
|
|
||||||
class AggregateStatLoggerBase(StatLoggerBase):
|
class AggregateStatLoggerBase(StatLoggerBase):
|
||||||
"""Abstract base class for loggers that
|
"""Abstract base class for loggers that
|
||||||
aggregate across multiple DP engines."""
|
aggregate across multiple DP engines."""
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user