diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 47b3efa6af72..c94e440e5c84 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -1490,14 +1490,18 @@ class EngineArgs: and _warn_or_fallback("Engine in background thread")): return False - if (self.pipeline_parallel_size > 1 - and self.distributed_executor_backend - not in (ParallelConfig.distributed_executor_backend, "ray", - "mp", "external_launcher")): - name = "Pipeline Parallelism without Ray distributed executor " \ - "or multiprocessing executor or external launcher" - _raise_or_fallback(feature_name=name, recommend_to_remove=False) - return False + if self.pipeline_parallel_size > 1: + supports_pp = getattr(self.distributed_executor_backend, + 'supports_pp', False) + if not supports_pp and self.distributed_executor_backend not in ( + ParallelConfig.distributed_executor_backend, "ray", "mp", + "external_launcher"): + name = "Pipeline Parallelism without Ray distributed " \ + "executor or multiprocessing executor or external " \ + "launcher" + _raise_or_fallback(feature_name=name, + recommend_to_remove=False) + return False # The platform may be supported on V1, but off by default for now. if not current_platform.default_v1( # noqa: SIM103 diff --git a/vllm/executor/executor_base.py b/vllm/executor/executor_base.py index 97d0d6f08b81..813232cd1928 100644 --- a/vllm/executor/executor_base.py +++ b/vllm/executor/executor_base.py @@ -35,6 +35,7 @@ class ExecutorBase(ABC): """ uses_ray: bool # whether the executor uses Ray for orchestration. + supports_pp: bool = False # whether the executor supports PP def __init__( self, diff --git a/vllm/v1/executor/multiproc_executor.py b/vllm/v1/executor/multiproc_executor.py index d90051c3224f..0db3bcd7fb40 100644 --- a/vllm/v1/executor/multiproc_executor.py +++ b/vllm/v1/executor/multiproc_executor.py @@ -41,6 +41,8 @@ logger = init_logger(__name__) class MultiprocExecutor(Executor): + supports_pp: bool = True + def _init_executor(self) -> None: # Call self.shutdown at exit to clean up # and ensure workers will be terminated. diff --git a/vllm/v1/executor/ray_distributed_executor.py b/vllm/v1/executor/ray_distributed_executor.py index b86ac048f520..c05ad1966d61 100644 --- a/vllm/v1/executor/ray_distributed_executor.py +++ b/vllm/v1/executor/ray_distributed_executor.py @@ -43,6 +43,8 @@ class FutureWrapper(Future): class RayDistributedExecutor(RayDistributedExecutorV0, Executor): """Ray distributed executor using Ray Compiled Graphs.""" + supports_pp: bool = True + def _init_executor(self) -> None: super()._init_executor()