[V1] Refactor get_executor_cls (#11754)

Rui Qiao 2025-01-05 23:59:16 -08:00 committed by GitHub
parent f8fcca100b
commit 022c5c6944
5 changed files with 26 additions and 46 deletions
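
In short: the duplicated private classmethods AsyncLLM._get_executor_cls and LLMEngine._get_executor_cls are folded into a single Executor.get_class staticmethod on the abstract executor base class, and every call site switches to it. A minimal sketch of the new call pattern, assuming any model name EngineArgs accepts (the placeholder below is not from this commit):

    from vllm.engine.arg_utils import EngineArgs
    from vllm.v1.executor.abstract import Executor

    engine_args = EngineArgs(model="facebook/opt-125m")  # placeholder model
    vllm_config = engine_args.create_engine_config()
    # One shared lookup instead of two per-engine private helpers:
    executor_class = Executor.get_class(vllm_config)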

View File

@@ -8,8 +8,8 @@ from vllm import SamplingParams
 from vllm.engine.arg_utils import EngineArgs
 from vllm.platforms import current_platform
 from vllm.v1.engine import EngineCoreRequest
-from vllm.v1.engine.async_llm import AsyncLLM
 from vllm.v1.engine.core import EngineCore
+from vllm.v1.executor.abstract import Executor
 
 if not current_platform.is_cuda():
     pytest.skip(reason="V1 currently only supported on CUDA.",
@@ -43,7 +43,7 @@ def test_engine_core(monkeypatch):
         """Setup the EngineCore."""
         engine_args = EngineArgs(model=MODEL_NAME)
         vllm_config = engine_args.create_engine_config()
-        executor_class = AsyncLLM._get_executor_cls(vllm_config)
+        executor_class = Executor.get_class(vllm_config)
 
         engine_core = EngineCore(vllm_config=vllm_config,
                                  executor_class=executor_class)
@@ -149,7 +149,7 @@ def test_engine_core_advanced_sampling(monkeypatch):
         """Setup the EngineCore."""
         engine_args = EngineArgs(model=MODEL_NAME)
         vllm_config = engine_args.create_engine_config()
-        executor_class = AsyncLLM._get_executor_cls(vllm_config)
+        executor_class = Executor.get_class(vllm_config)
 
         engine_core = EngineCore(vllm_config=vllm_config,
                                  executor_class=executor_class)
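
Condensed, the test setup after this change reduces to the following sketch (MODEL_NAME is the constant the test module already defines; no test behavior changes, only where the executor class comes from):

    engine_args = EngineArgs(model=MODEL_NAME)
    vllm_config = engine_args.create_engine_config()
    executor_class = Executor.get_class(vllm_config)  # was AsyncLLM._get_executor_cls
    engine_core = EngineCore(vllm_config=vllm_config,
                             executor_class=executor_class)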

View File

@@ -11,8 +11,8 @@ from vllm.engine.arg_utils import EngineArgs
 from vllm.platforms import current_platform
 from vllm.usage.usage_lib import UsageContext
 from vllm.v1.engine import EngineCoreRequest
-from vllm.v1.engine.async_llm import AsyncLLM
 from vllm.v1.engine.core_client import EngineCoreClient
+from vllm.v1.executor.abstract import Executor
 
 if not current_platform.is_cuda():
     pytest.skip(reason="V1 currently only supported on CUDA.",
@@ -84,7 +84,7 @@ def test_engine_core_client(monkeypatch, multiprocessing_mode: bool):
         engine_args = EngineArgs(model=MODEL_NAME, compilation_config=3)
         vllm_config = engine_args.create_engine_config(
             UsageContext.UNKNOWN_CONTEXT)
-        executor_class = AsyncLLM._get_executor_cls(vllm_config)
+        executor_class = Executor.get_class(vllm_config)
         client = EngineCoreClient.make_client(
             multiprocess_mode=multiprocessing_mode,
             asyncio_mode=False,
@@ -152,7 +152,7 @@ async def test_engine_core_client_asyncio(monkeypatch):
         engine_args = EngineArgs(model=MODEL_NAME)
         vllm_config = engine_args.create_engine_config(
             usage_context=UsageContext.UNKNOWN_CONTEXT)
-        executor_class = AsyncLLM._get_executor_cls(vllm_config)
+        executor_class = Executor.get_class(vllm_config)
         client = EngineCoreClient.make_client(
             multiprocess_mode=True,
             asyncio_mode=True,

View File

@@ -22,7 +22,6 @@ from vllm.v1.engine.core_client import EngineCoreClient
 from vllm.v1.engine.detokenizer import Detokenizer
 from vllm.v1.engine.processor import Processor
 from vllm.v1.executor.abstract import Executor
-from vllm.v1.executor.ray_utils import initialize_ray_cluster
 
 logger = init_logger(__name__)
 
@@ -105,7 +104,7 @@ class AsyncLLM(EngineClient):
         else:
             vllm_config = engine_config
 
-        executor_class = cls._get_executor_cls(vllm_config)
+        executor_class = Executor.get_class(vllm_config)
 
         # Create the AsyncLLM.
         return cls(
@@ -127,24 +126,6 @@ class AsyncLLM(EngineClient):
         if handler := getattr(self, "output_handler", None):
             handler.cancel()
 
-    @classmethod
-    def _get_executor_cls(cls, vllm_config: VllmConfig) -> Type[Executor]:
-        executor_class: Type[Executor]
-        distributed_executor_backend = (
-            vllm_config.parallel_config.distributed_executor_backend)
-        if distributed_executor_backend == "ray":
-            initialize_ray_cluster(vllm_config.parallel_config)
-            from vllm.v1.executor.ray_executor import RayExecutor
-            executor_class = RayExecutor
-        elif distributed_executor_backend == "mp":
-            from vllm.v1.executor.multiproc_executor import MultiprocExecutor
-            executor_class = MultiprocExecutor
-        else:
-            assert (distributed_executor_backend is None)
-            from vllm.v1.executor.uniproc_executor import UniprocExecutor
-            executor_class = UniprocExecutor
-        return executor_class
-
     async def add_request(
         self,
         request_id: str,

View File

@@ -89,7 +89,7 @@ class LLMEngine:
         # Create the engine configs.
         vllm_config = engine_args.create_engine_config(usage_context)
 
-        executor_class = cls._get_executor_cls(vllm_config)
+        executor_class = Executor.get_class(vllm_config)
 
         if VLLM_ENABLE_V1_MULTIPROCESSING:
             logger.debug("Enabling multiprocessing for LLMEngine.")
@@ -103,24 +103,6 @@ class LLMEngine:
             stat_loggers=stat_loggers,
             multiprocess_mode=enable_multiprocessing)
 
-    @classmethod
-    def _get_executor_cls(cls, vllm_config: VllmConfig) -> Type[Executor]:
-        executor_class: Type[Executor]
-        distributed_executor_backend = (
-            vllm_config.parallel_config.distributed_executor_backend)
-        if distributed_executor_backend == "ray":
-            from vllm.v1.executor.ray_executor import RayExecutor
-            executor_class = RayExecutor
-        elif distributed_executor_backend == "mp":
-            from vllm.v1.executor.multiproc_executor import MultiprocExecutor
-            executor_class = MultiprocExecutor
-        else:
-            assert (distributed_executor_backend is None)
-            from vllm.v1.executor.uniproc_executor import UniprocExecutor
-            executor_class = UniprocExecutor
-        return executor_class
-
     def get_num_unfinished_requests(self) -> int:
         return self.detokenizer.get_num_unfinished_requests()

View File

@@ -1,5 +1,5 @@
 from abc import ABC, abstractmethod
-from typing import Tuple
+from typing import Tuple, Type
 
 from vllm.config import VllmConfig
 from vllm.v1.outputs import ModelRunnerOutput
@@ -8,6 +8,23 @@ from vllm.v1.outputs import ModelRunnerOutput
 class Executor(ABC):
     """Abstract class for executors."""
 
+    @staticmethod
+    def get_class(vllm_config: VllmConfig) -> Type["Executor"]:
+        executor_class: Type[Executor]
+        distributed_executor_backend = (
+            vllm_config.parallel_config.distributed_executor_backend)
+        if distributed_executor_backend == "ray":
+            from vllm.v1.executor.ray_executor import RayExecutor
+            executor_class = RayExecutor
+        elif distributed_executor_backend == "mp":
+            from vllm.v1.executor.multiproc_executor import MultiprocExecutor
+            executor_class = MultiprocExecutor
+        else:
+            assert (distributed_executor_backend is None)
+            from vllm.v1.executor.uniproc_executor import UniprocExecutor
+            executor_class = UniprocExecutor
+        return executor_class
+
     @abstractmethod
     def __init__(self, vllm_config: VllmConfig) -> None:
         raise NotImplementedError
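
The selection logic keys entirely off parallel_config.distributed_executor_backend: "ray" maps to RayExecutor, "mp" to MultiprocExecutor, and None to UniprocExecutor; any other value trips the assert. A hedged usage sketch, assuming EngineArgs exposes distributed_executor_backend as a constructor argument (that plumbing is not part of this diff):

    from vllm.engine.arg_utils import EngineArgs
    from vllm.v1.executor.abstract import Executor

    engine_args = EngineArgs(model="facebook/opt-125m",  # placeholder model
                             distributed_executor_backend="mp")
    vllm_config = engine_args.create_engine_config()
    # Expect the multiprocessing executor for the "mp" backend.
    assert Executor.get_class(vllm_config).__name__ == "MultiprocExecutor"

One behavioral note visible in the diff: the removed AsyncLLM._get_executor_cls called initialize_ray_cluster before importing RayExecutor, while the consolidated Executor.get_class does not.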