[Core] Support custom executor qualname (#23314)
Signed-off-by: 22quinn <33176974+22quinn@users.noreply.github.com>
commit 480bdf5a7b (parent 5368f76855)
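In practice, this change lets `distributed_executor_backend` accept a dotted import path ("qualname") in addition to the built-in backend names and an `ExecutorBase` subclass passed as a class object. A minimal usage sketch, where the path `my_project.executors.CustomExecutor` is a hypothetical placeholder for any importable `ExecutorBase` subclass:

    from vllm.engine.arg_utils import EngineArgs
    from vllm.v1.engine.llm_engine import LLMEngine

    # The backend can now be given as an import path string; vLLM resolves it
    # and checks that the resolved class is an ExecutorBase subclass.
    engine_args = EngineArgs(
        model="Qwen/Qwen3-0.6B",
        # Hypothetical custom executor import path.
        distributed_executor_backend="my_project.executors.CustomExecutor",
    )
    engine = LLMEngine.from_engine_args(engine_args)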
@@ -244,6 +244,7 @@ steps:
   - pytest -v -s v1/core
   - pytest -v -s v1/engine
   - pytest -v -s v1/entrypoints
+  - pytest -v -s v1/executor
   - pytest -v -s v1/sample
   - pytest -v -s v1/logits_processors
   - pytest -v -s v1/worker
0	tests/v1/executor/__init__.py	Normal file
116	tests/v1/executor/test_executor.py	Normal file
@@ -0,0 +1,116 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import asyncio
import os
from typing import Any, Callable, Optional, Union

import pytest

from vllm.engine.arg_utils import AsyncEngineArgs, EngineArgs
from vllm.sampling_params import SamplingParams
from vllm.v1.engine.async_llm import AsyncLLM
from vllm.v1.engine.llm_engine import LLMEngine
from vllm.v1.executor.multiproc_executor import MultiprocExecutor


class Mock:
    # Deliberately not an ExecutorBase subclass; used to trigger the
    # type-checking errors below.
    ...


class CustomMultiprocExecutor(MultiprocExecutor):

    def collective_rpc(self,
                       method: Union[str, Callable],
                       timeout: Optional[float] = None,
                       args: tuple = (),
                       kwargs: Optional[dict] = None,
                       non_block: bool = False,
                       unique_reply_rank: Optional[int] = None) -> list[Any]:
        # Drop a marker file to show that this method was run.
        with open(".marker", "w"):
            ...
        return super().collective_rpc(method, timeout, args, kwargs)


CustomMultiprocExecutorAsync = CustomMultiprocExecutor
MODEL = "Qwen/Qwen3-0.6B"


def test_custom_executor_type_checking():
    with pytest.raises(ValueError):
        engine_args = EngineArgs(
            model=MODEL,
            gpu_memory_utilization=0.2,
            max_model_len=8192,
            distributed_executor_backend=Mock,
        )
        LLMEngine.from_engine_args(engine_args)
    with pytest.raises(ValueError):
        engine_args = AsyncEngineArgs(model=MODEL,
                                      gpu_memory_utilization=0.2,
                                      max_model_len=8192,
                                      distributed_executor_backend=Mock)
        AsyncLLM.from_engine_args(engine_args)


@pytest.mark.parametrize("distributed_executor_backend", [
    CustomMultiprocExecutor,
    "tests.v1.executor.test_executor.CustomMultiprocExecutor"
])
def test_custom_executor(distributed_executor_backend, tmp_path):
    cwd = os.path.abspath(".")
    os.chdir(tmp_path)
    try:
        assert not os.path.exists(".marker")

        engine_args = EngineArgs(
            model=MODEL,
            gpu_memory_utilization=0.2,
            max_model_len=8192,
            distributed_executor_backend=distributed_executor_backend,
            enforce_eager=True,  # reduce test time
        )
        engine = LLMEngine.from_engine_args(engine_args)
        sampling_params = SamplingParams(max_tokens=1)

        engine.add_request("0", "foo", sampling_params)
        engine.step()

        assert os.path.exists(".marker")
    finally:
        os.chdir(cwd)


@pytest.mark.parametrize("distributed_executor_backend", [
    CustomMultiprocExecutorAsync,
    "tests.v1.executor.test_executor.CustomMultiprocExecutorAsync"
])
def test_custom_executor_async(distributed_executor_backend, tmp_path):
    cwd = os.path.abspath(".")
    os.chdir(tmp_path)
    try:
        assert not os.path.exists(".marker")

        engine_args = AsyncEngineArgs(
            model=MODEL,
            gpu_memory_utilization=0.2,
            max_model_len=8192,
            distributed_executor_backend=distributed_executor_backend,
            enforce_eager=True,  # reduce test time
        )
        engine = AsyncLLM.from_engine_args(engine_args)
        sampling_params = SamplingParams(max_tokens=1)

        async def t():
            stream = engine.generate(request_id="0",
                                     prompt="foo",
                                     sampling_params=sampling_params)
            async for x in stream:
                ...

        asyncio.run(t())

        assert os.path.exists(".marker")
    finally:
        os.chdir(cwd)
@@ -143,7 +143,8 @@ class ParallelConfig:
     placement_group: Optional[PlacementGroup] = None
     """ray distributed model workers placement group."""
 
-    distributed_executor_backend: Optional[Union[DistributedExecutorBackend,
+    distributed_executor_backend: Optional[Union[str,
+                                                 DistributedExecutorBackend,
                                                  type[ExecutorBase]]] = None
     """Backend to use for distributed model
     workers, either "ray" or "mp" (multiprocessing). If the product
@@ -416,23 +417,22 @@ class ParallelConfig:
     def use_ray(self) -> bool:
         return self.distributed_executor_backend == "ray" or (
             isinstance(self.distributed_executor_backend, type)
-            and self.distributed_executor_backend.uses_ray)
+            and getattr(self.distributed_executor_backend, "uses_ray", False))
 
     @model_validator(mode='after')
     def _verify_args(self) -> Self:
         # Lazy import to avoid circular import
         from vllm.executor.executor_base import ExecutorBase
         from vllm.platforms import current_platform
-        if self.distributed_executor_backend not in (
-                "ray", "mp", "uni",
-                "external_launcher", None) and not (isinstance(
+        if self.distributed_executor_backend is not None and not isinstance(
+                self.distributed_executor_backend, str) and not (isinstance(
                     self.distributed_executor_backend, type) and issubclass(
                         self.distributed_executor_backend, ExecutorBase)):
             raise ValueError(
                 "Unrecognized distributed executor backend "
                 f"{self.distributed_executor_backend}. Supported "
-                "values are 'ray', 'mp' 'uni', 'external_launcher' or"
-                " custom ExecutorBase subclass.")
+                "values are 'ray', 'mp' 'uni', 'external_launcher', "
+                " custom ExecutorBase subclass or its import path.")
         if self.use_ray:
             from vllm.executor import ray_utils
             ray_utils.assert_ray_available()
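Read together, the validator now treats any string as potentially valid (a built-in backend name or an import path, resolved later by the executor factory) and still accepts ExecutorBase subclasses, while rejecting other objects up front. A standalone sketch of the accepted shapes, mirroring the condition above (the helper name is illustrative only, not a vLLM API):

    from vllm.executor.executor_base import ExecutorBase

    def _backend_is_acceptable(backend) -> bool:
        # None, any string (e.g. "mp", "ray", or a dotted import path), or an
        # ExecutorBase subclass pass the check; anything else makes
        # _verify_args raise ValueError.
        return (backend is None or isinstance(backend, str)
                or (isinstance(backend, type) and issubclass(backend, ExecutorBase)))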
@@ -290,7 +290,7 @@ class EngineArgs:
     # is intended for expert use only. The API may change without
     # notice.
     distributed_executor_backend: Optional[Union[
-        DistributedExecutorBackend,
+        str, DistributedExecutorBackend,
         Type[ExecutorBase]]] = ParallelConfig.distributed_executor_backend
     # number of P/D disaggregation (or other disaggregation) workers
     pipeline_parallel_size: int = ParallelConfig.pipeline_parallel_size
@@ -13,6 +13,7 @@ from vllm.executor.uniproc_executor import ( # noqa
     ExecutorWithExternalLauncher as ExecutorWithExternalLauncherV0)
 from vllm.executor.uniproc_executor import ( # noqa
     UniProcExecutor as UniProcExecutorV0)
+from vllm.utils import resolve_obj_by_qualname
 from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec
 from vllm.v1.outputs import DraftTokenIds, ModelRunnerOutput
@@ -50,6 +51,13 @@ class Executor(ExecutorBase):
             # TODO: make v1 scheduling deterministic
             # to support external launcher
             executor_class = ExecutorWithExternalLauncher
+        elif isinstance(distributed_executor_backend, str):
+            executor_class = resolve_obj_by_qualname(
+                distributed_executor_backend)
+            if not issubclass(executor_class, ExecutorBase):
+                raise TypeError(
+                    "distributed_executor_backend must be a subclass of "
+                    f"ExecutorBase. Got {executor_class}.")
         else:
             raise ValueError("Unknown distributed executor backend: "
                              f"{distributed_executor_backend}")
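`resolve_obj_by_qualname` from `vllm.utils` is what turns the dotted string into an actual class before the `issubclass` check above. A minimal sketch of the same idea (an assumption about its behavior, not vLLM's exact implementation):

    import importlib

    def resolve_obj_by_qualname_sketch(qualname: str):
        # Split "pkg.module.ClassName" into module path and attribute name,
        # import the module, then look up the attribute on it.
        module_name, _, obj_name = qualname.rpartition(".")
        module = importlib.import_module(module_name)
        return getattr(module, obj_name)

    # e.g. resolve_obj_by_qualname_sketch(
    #     "tests.v1.executor.test_executor.CustomMultiprocExecutor")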