[Core] Support custom executor qualname (#23314)

Signed-off-by: 22quinn <33176974+22quinn@users.noreply.github.com>
22quinn 2025-08-21 18:40:54 -07:00 (committed by GitHub)
parent 5368f76855
commit 480bdf5a7b
6 changed files with 133 additions and 8 deletions
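
In short: distributed_executor_backend now also accepts a fully qualified
import path to a custom ExecutorBase subclass, in addition to the built-in
backend names ("ray", "mp", "uni", "external_launcher") and a class object.
A minimal usage sketch; my_pkg.executors.MyExecutor is a hypothetical
placeholder for your own subclass:

from vllm.engine.arg_utils import EngineArgs
from vllm.v1.engine.llm_engine import LLMEngine

# The string is resolved at engine construction time and must point to an
# ExecutorBase subclass, otherwise a TypeError is raised.
engine_args = EngineArgs(
    model="Qwen/Qwen3-0.6B",
    distributed_executor_backend="my_pkg.executors.MyExecutor",  # hypothetical
)
engine = LLMEngine.from_engine_args(engine_args)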

View File

@@ -244,6 +244,7 @@ steps:
   - pytest -v -s v1/core
   - pytest -v -s v1/engine
   - pytest -v -s v1/entrypoints
+  - pytest -v -s v1/executor
   - pytest -v -s v1/sample
   - pytest -v -s v1/logits_processors
   - pytest -v -s v1/worker

View File

View File

@@ -0,0 +1,116 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import asyncio
import os
from typing import Any, Callable, Optional, Union

import pytest

from vllm.engine.arg_utils import AsyncEngineArgs, EngineArgs
from vllm.sampling_params import SamplingParams
from vllm.v1.engine.async_llm import AsyncLLM
from vllm.v1.engine.llm_engine import LLMEngine
from vllm.v1.executor.multiproc_executor import MultiprocExecutor


class Mock:
    ...


class CustomMultiprocExecutor(MultiprocExecutor):

    def collective_rpc(self,
                       method: Union[str, Callable],
                       timeout: Optional[float] = None,
                       args: tuple = (),
                       kwargs: Optional[dict] = None,
                       non_block: bool = False,
                       unique_reply_rank: Optional[int] = None) -> list[Any]:
        # Drop a marker file to show that this was run
        with open(".marker", "w"):
            ...
        return super().collective_rpc(method, timeout, args, kwargs)


CustomMultiprocExecutorAsync = CustomMultiprocExecutor

MODEL = "Qwen/Qwen3-0.6B"


def test_custom_executor_type_checking():
    with pytest.raises(ValueError):
        engine_args = EngineArgs(
            model=MODEL,
            gpu_memory_utilization=0.2,
            max_model_len=8192,
            distributed_executor_backend=Mock,
        )
        LLMEngine.from_engine_args(engine_args)
    with pytest.raises(ValueError):
        engine_args = AsyncEngineArgs(model=MODEL,
                                      gpu_memory_utilization=0.2,
                                      max_model_len=8192,
                                      distributed_executor_backend=Mock)
        AsyncLLM.from_engine_args(engine_args)


@pytest.mark.parametrize("distributed_executor_backend", [
    CustomMultiprocExecutor,
    "tests.v1.executor.test_executor.CustomMultiprocExecutor"
])
def test_custom_executor(distributed_executor_backend, tmp_path):
    cwd = os.path.abspath(".")
    os.chdir(tmp_path)
    try:
        assert not os.path.exists(".marker")

        engine_args = EngineArgs(
            model=MODEL,
            gpu_memory_utilization=0.2,
            max_model_len=8192,
            distributed_executor_backend=distributed_executor_backend,
            enforce_eager=True,  # reduce test time
        )
        engine = LLMEngine.from_engine_args(engine_args)
        sampling_params = SamplingParams(max_tokens=1)

        engine.add_request("0", "foo", sampling_params)
        engine.step()

        assert os.path.exists(".marker")
    finally:
        os.chdir(cwd)


@pytest.mark.parametrize("distributed_executor_backend", [
    CustomMultiprocExecutorAsync,
    "tests.v1.executor.test_executor.CustomMultiprocExecutorAsync"
])
def test_custom_executor_async(distributed_executor_backend, tmp_path):
    cwd = os.path.abspath(".")
    os.chdir(tmp_path)
    try:
        assert not os.path.exists(".marker")

        engine_args = AsyncEngineArgs(
            model=MODEL,
            gpu_memory_utilization=0.2,
            max_model_len=8192,
            distributed_executor_backend=distributed_executor_backend,
            enforce_eager=True,  # reduce test time
        )
        engine = AsyncLLM.from_engine_args(engine_args)
        sampling_params = SamplingParams(max_tokens=1)

        async def t():
            stream = engine.generate(request_id="0",
                                     prompt="foo",
                                     sampling_params=sampling_params)
            async for x in stream:
                ...

        asyncio.run(t())

        assert os.path.exists(".marker")
    finally:
        os.chdir(cwd)

View File

@@ -143,7 +143,8 @@ class ParallelConfig:
     placement_group: Optional[PlacementGroup] = None
     """ray distributed model workers placement group."""
-    distributed_executor_backend: Optional[Union[DistributedExecutorBackend,
+    distributed_executor_backend: Optional[Union[str,
+                                                 DistributedExecutorBackend,
                                                  type[ExecutorBase]]] = None
     """Backend to use for distributed model
     workers, either "ray" or "mp" (multiprocessing). If the product
@@ -416,23 +417,22 @@ class ParallelConfig:
     def use_ray(self) -> bool:
         return self.distributed_executor_backend == "ray" or (
             isinstance(self.distributed_executor_backend, type)
-            and self.distributed_executor_backend.uses_ray)
+            and getattr(self.distributed_executor_backend, "uses_ray", False))

     @model_validator(mode='after')
     def _verify_args(self) -> Self:
         # Lazy import to avoid circular import
         from vllm.executor.executor_base import ExecutorBase
         from vllm.platforms import current_platform

-        if self.distributed_executor_backend not in (
-                "ray", "mp", "uni",
-                "external_launcher", None) and not (isinstance(
+        if self.distributed_executor_backend is not None and not isinstance(
+                self.distributed_executor_backend, str) and not (isinstance(
                     self.distributed_executor_backend, type) and issubclass(
                         self.distributed_executor_backend, ExecutorBase)):
             raise ValueError(
                 "Unrecognized distributed executor backend "
                 f"{self.distributed_executor_backend}. Supported "
-                "values are 'ray', 'mp' 'uni', 'external_launcher' or"
-                " custom ExecutorBase subclass.")
+                "values are 'ray', 'mp' 'uni', 'external_launcher', "
+                " custom ExecutorBase subclass or its import path.")
         if self.use_ray:
             from vllm.executor import ray_utils
             ray_utils.assert_ray_available()
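
Net effect of the relaxed validator: any string now passes _verify_args (an
import path is only resolved and type-checked later, when the executor class
is looked up), while types that are not ExecutorBase subclasses are still
rejected up front. A small sketch of the accepted values, assuming
ParallelConfig is constructed directly; my_pkg.MyExecutor is hypothetical:

from vllm.config import ParallelConfig
from vllm.v1.executor.multiproc_executor import MultiprocExecutor

ParallelConfig(distributed_executor_backend="mp")               # built-in name
ParallelConfig(distributed_executor_backend=MultiprocExecutor)  # class object
ParallelConfig(distributed_executor_backend="my_pkg.MyExecutor")  # import path
# ParallelConfig(distributed_executor_backend=int)  # still raises ValueError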

View File

@@ -290,7 +290,7 @@ class EngineArgs:
     # is intended for expert use only. The API may change without
     # notice.
     distributed_executor_backend: Optional[Union[
-        DistributedExecutorBackend,
+        str, DistributedExecutorBackend,
         Type[ExecutorBase]]] = ParallelConfig.distributed_executor_backend
     # number of P/D disaggregation (or other disaggregation) workers
     pipeline_parallel_size: int = ParallelConfig.pipeline_parallel_size

View File

@@ -13,6 +13,7 @@ from vllm.executor.uniproc_executor import ( # noqa
     ExecutorWithExternalLauncher as ExecutorWithExternalLauncherV0)
 from vllm.executor.uniproc_executor import ( # noqa
     UniProcExecutor as UniProcExecutorV0)
+from vllm.utils import resolve_obj_by_qualname
 from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec
 from vllm.v1.outputs import DraftTokenIds, ModelRunnerOutput
@@ -50,6 +51,13 @@ class Executor(ExecutorBase):
             # TODO: make v1 scheduling deterministic
             # to support external launcher
             executor_class = ExecutorWithExternalLauncher
+        elif isinstance(distributed_executor_backend, str):
+            executor_class = resolve_obj_by_qualname(
+                distributed_executor_backend)
+            if not issubclass(executor_class, ExecutorBase):
+                raise TypeError(
+                    "distributed_executor_backend must be a subclass of "
+                    f"ExecutorBase. Got {executor_class}.")
         else:
             raise ValueError("Unknown distributed executor backend: "
                              f"{distributed_executor_backend}")