[V1] Refactor get_executor_cls (#11754)
parent f8fcca100b
commit 022c5c6944
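This refactor replaces the duplicated AsyncLLM._get_executor_cls and LLMEngine._get_executor_cls classmethods with a single Executor.get_class static method on the abstract executor, and updates the V1 engine tests, AsyncLLM, and LLMEngine to call it. A minimal sketch of the new call pattern follows; it is an illustration, not code from the commit, and the model name is a placeholder (the tests in this diff use their own MODEL_NAME constant):

# Hedged sketch of the call-site change introduced by this commit.
from vllm.engine.arg_utils import EngineArgs
from vllm.v1.engine.core import EngineCore
from vllm.v1.executor.abstract import Executor

engine_args = EngineArgs(model="facebook/opt-125m")  # placeholder model name
vllm_config = engine_args.create_engine_config()

# Before: executor_class = AsyncLLM._get_executor_cls(vllm_config)
# After:  one shared helper on the abstract Executor base class.
executor_class = Executor.get_class(vllm_config)
engine_core = EngineCore(vllm_config=vllm_config,
                         executor_class=executor_class)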
@@ -8,8 +8,8 @@ from vllm import SamplingParams
 from vllm.engine.arg_utils import EngineArgs
 from vllm.platforms import current_platform
 from vllm.v1.engine import EngineCoreRequest
-from vllm.v1.engine.async_llm import AsyncLLM
 from vllm.v1.engine.core import EngineCore
+from vllm.v1.executor.abstract import Executor
 
 if not current_platform.is_cuda():
     pytest.skip(reason="V1 currently only supported on CUDA.",
@@ -43,7 +43,7 @@ def test_engine_core(monkeypatch):
     """Setup the EngineCore."""
     engine_args = EngineArgs(model=MODEL_NAME)
     vllm_config = engine_args.create_engine_config()
-    executor_class = AsyncLLM._get_executor_cls(vllm_config)
+    executor_class = Executor.get_class(vllm_config)
 
     engine_core = EngineCore(vllm_config=vllm_config,
                              executor_class=executor_class)
@@ -149,7 +149,7 @@ def test_engine_core_advanced_sampling(monkeypatch):
     """Setup the EngineCore."""
     engine_args = EngineArgs(model=MODEL_NAME)
     vllm_config = engine_args.create_engine_config()
-    executor_class = AsyncLLM._get_executor_cls(vllm_config)
+    executor_class = Executor.get_class(vllm_config)
 
     engine_core = EngineCore(vllm_config=vllm_config,
                              executor_class=executor_class)

@@ -11,8 +11,8 @@ from vllm.engine.arg_utils import EngineArgs
 from vllm.platforms import current_platform
 from vllm.usage.usage_lib import UsageContext
 from vllm.v1.engine import EngineCoreRequest
-from vllm.v1.engine.async_llm import AsyncLLM
 from vllm.v1.engine.core_client import EngineCoreClient
+from vllm.v1.executor.abstract import Executor
 
 if not current_platform.is_cuda():
     pytest.skip(reason="V1 currently only supported on CUDA.",
@@ -84,7 +84,7 @@ def test_engine_core_client(monkeypatch, multiprocessing_mode: bool):
     engine_args = EngineArgs(model=MODEL_NAME, compilation_config=3)
     vllm_config = engine_args.create_engine_config(
         UsageContext.UNKNOWN_CONTEXT)
-    executor_class = AsyncLLM._get_executor_cls(vllm_config)
+    executor_class = Executor.get_class(vllm_config)
     client = EngineCoreClient.make_client(
         multiprocess_mode=multiprocessing_mode,
         asyncio_mode=False,
@@ -152,7 +152,7 @@ async def test_engine_core_client_asyncio(monkeypatch):
     engine_args = EngineArgs(model=MODEL_NAME)
     vllm_config = engine_args.create_engine_config(
         usage_context=UsageContext.UNKNOWN_CONTEXT)
-    executor_class = AsyncLLM._get_executor_cls(vllm_config)
+    executor_class = Executor.get_class(vllm_config)
     client = EngineCoreClient.make_client(
         multiprocess_mode=True,
         asyncio_mode=True,

@@ -22,7 +22,6 @@ from vllm.v1.engine.core_client import EngineCoreClient
 from vllm.v1.engine.detokenizer import Detokenizer
 from vllm.v1.engine.processor import Processor
 from vllm.v1.executor.abstract import Executor
-from vllm.v1.executor.ray_utils import initialize_ray_cluster
 
 logger = init_logger(__name__)
 
@@ -105,7 +104,7 @@ class AsyncLLM(EngineClient):
         else:
             vllm_config = engine_config
 
-        executor_class = cls._get_executor_cls(vllm_config)
+        executor_class = Executor.get_class(vllm_config)
 
         # Create the AsyncLLM.
         return cls(
@@ -127,24 +126,6 @@ class AsyncLLM(EngineClient):
         if handler := getattr(self, "output_handler", None):
             handler.cancel()
 
-    @classmethod
-    def _get_executor_cls(cls, vllm_config: VllmConfig) -> Type[Executor]:
-        executor_class: Type[Executor]
-        distributed_executor_backend = (
-            vllm_config.parallel_config.distributed_executor_backend)
-        if distributed_executor_backend == "ray":
-            initialize_ray_cluster(vllm_config.parallel_config)
-            from vllm.v1.executor.ray_executor import RayExecutor
-            executor_class = RayExecutor
-        elif distributed_executor_backend == "mp":
-            from vllm.v1.executor.multiproc_executor import MultiprocExecutor
-            executor_class = MultiprocExecutor
-        else:
-            assert (distributed_executor_backend is None)
-            from vllm.v1.executor.uniproc_executor import UniprocExecutor
-            executor_class = UniprocExecutor
-        return executor_class
-
     async def add_request(
         self,
         request_id: str,

@@ -89,7 +89,7 @@ class LLMEngine:
 
         # Create the engine configs.
         vllm_config = engine_args.create_engine_config(usage_context)
-        executor_class = cls._get_executor_cls(vllm_config)
+        executor_class = Executor.get_class(vllm_config)
 
         if VLLM_ENABLE_V1_MULTIPROCESSING:
             logger.debug("Enabling multiprocessing for LLMEngine.")
@@ -103,24 +103,6 @@ class LLMEngine:
                   stat_loggers=stat_loggers,
                   multiprocess_mode=enable_multiprocessing)
 
-    @classmethod
-    def _get_executor_cls(cls, vllm_config: VllmConfig) -> Type[Executor]:
-        executor_class: Type[Executor]
-        distributed_executor_backend = (
-            vllm_config.parallel_config.distributed_executor_backend)
-        if distributed_executor_backend == "ray":
-            from vllm.v1.executor.ray_executor import RayExecutor
-            executor_class = RayExecutor
-        elif distributed_executor_backend == "mp":
-            from vllm.v1.executor.multiproc_executor import MultiprocExecutor
-            executor_class = MultiprocExecutor
-        else:
-            assert (distributed_executor_backend is None)
-            from vllm.v1.executor.uniproc_executor import UniprocExecutor
-            executor_class = UniprocExecutor
-
-        return executor_class
-
     def get_num_unfinished_requests(self) -> int:
         return self.detokenizer.get_num_unfinished_requests()
 

@@ -1,5 +1,5 @@
 from abc import ABC, abstractmethod
-from typing import Tuple
+from typing import Tuple, Type
 
 from vllm.config import VllmConfig
 from vllm.v1.outputs import ModelRunnerOutput
@@ -8,6 +8,23 @@ from vllm.v1.outputs import ModelRunnerOutput
 class Executor(ABC):
     """Abstract class for executors."""
 
+    @staticmethod
+    def get_class(vllm_config: VllmConfig) -> Type["Executor"]:
+        executor_class: Type[Executor]
+        distributed_executor_backend = (
+            vllm_config.parallel_config.distributed_executor_backend)
+        if distributed_executor_backend == "ray":
+            from vllm.v1.executor.ray_executor import RayExecutor
+            executor_class = RayExecutor
+        elif distributed_executor_backend == "mp":
+            from vllm.v1.executor.multiproc_executor import MultiprocExecutor
+            executor_class = MultiprocExecutor
+        else:
+            assert (distributed_executor_backend is None)
+            from vllm.v1.executor.uniproc_executor import UniprocExecutor
+            executor_class = UniprocExecutor
+        return executor_class
+
     @abstractmethod
     def __init__(self, vllm_config: VllmConfig) -> None:
         raise NotImplementedError
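The shared helper dispatches on vllm_config.parallel_config.distributed_executor_backend: "ray" selects RayExecutor, "mp" selects MultiprocExecutor, and an unset backend falls through to UniprocExecutor. Unlike the removed AsyncLLM._get_executor_cls, the Ray branch no longer calls initialize_ray_cluster, and the ray_utils import is dropped from the AsyncLLM module. Below is a hedged usage sketch; passing distributed_executor_backend through EngineArgs is assumed from vllm's arg_utils and is not part of this diff, and the model name is a placeholder:

# Sketch only: resolving an executor class for an explicit backend choice.
from vllm.engine.arg_utils import EngineArgs
from vllm.v1.executor.abstract import Executor

# distributed_executor_backend as an EngineArgs field is an assumption here.
config = EngineArgs(model="facebook/opt-125m",  # placeholder model name
                    distributed_executor_backend="mp").create_engine_config()
assert Executor.get_class(config).__name__ == "MultiprocExecutor"

# With the defaults used by the tests in this diff (no backend set), get_class
# takes the assert-None branch and returns the single-process UniprocExecutor.
default_config = EngineArgs(model="facebook/opt-125m").create_engine_config()
assert Executor.get_class(default_config).__name__ == "UniprocExecutor"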